mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 18:44:07 +08:00
Merge branch 'master' into py39
This commit is contained in:
+2
-46
@@ -20,50 +20,6 @@ before_install:
|
||||
|
||||
matrix:
|
||||
include:
|
||||
- os: linux
|
||||
env:
|
||||
- PYTHON=3.6 SMALL_AND_LARGE_TESTS=1 RAY_ENABLE_NEW_SCHEDULER=1
|
||||
- PYTHONWARNINGS=ignore
|
||||
- RAY_DEFAULT_BUILD=1
|
||||
- RAY_CYTHON_EXAMPLES=1
|
||||
- RAY_USE_RANDOM_PORTS=1
|
||||
install:
|
||||
- . ./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED,RAY_CI_DASHBOARD_AFFECTED
|
||||
before_script:
|
||||
- . ./ci/travis/ci.sh build
|
||||
script:
|
||||
# bazel python tests. This should be run last to keep its logs at the end of travis logs.
|
||||
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,-medium_size_python_tests_a_to_j,-medium_size_python_tests_k_to_z,-new_scheduler_broken python/ray/tests/...; fi
|
||||
|
||||
- os: linux
|
||||
env:
|
||||
- PYTHON=3.6 MEDIUM_TESTS_A_TO_J=1 RAY_ENABLE_NEW_SCHEDULER=1
|
||||
- PYTHONWARNINGS=ignore
|
||||
- RAY_DEFAULT_BUILD=1
|
||||
- RAY_CYTHON_EXAMPLES=1
|
||||
- RAY_USE_RANDOM_PORTS=1
|
||||
install:
|
||||
- . ./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED,RAY_CI_DASHBOARD_AFFECTED
|
||||
before_script:
|
||||
- . ./ci/travis/ci.sh build
|
||||
script:
|
||||
# bazel python tests for medium size tests. Used for parallelization.
|
||||
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_a_to_j,-new_scheduler_broken python/ray/tests/...; fi
|
||||
|
||||
- os: linux
|
||||
env:
|
||||
- PYTHON=3.6 MEDIUM_TESTS_K_TO_Z=1 RAY_ENABLE_NEW_SCHEDULER=1
|
||||
- PYTHONWARNINGS=ignore
|
||||
- RAY_DEFAULT_BUILD=1
|
||||
- RAY_CYTHON_EXAMPLES=1
|
||||
- RAY_USE_RANDOM_PORTS=1
|
||||
install:
|
||||
- . ./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED,RAY_CI_DASHBOARD_AFFECTED
|
||||
before_script:
|
||||
- . ./ci/travis/ci.sh build
|
||||
script:
|
||||
# bazel python tests for medium size tests. Used for parallelization.
|
||||
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_k_to_z,-new_scheduler_broken python/ray/tests/...; fi
|
||||
- os: linux
|
||||
env:
|
||||
- PYTHON=3.6 SMALL_AND_LARGE_TESTS=1
|
||||
@@ -90,6 +46,7 @@ matrix:
|
||||
script:
|
||||
# bazel python tests for medium size tests. Used for parallelization.
|
||||
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_a_to_j python/ray/tests/...; fi
|
||||
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,client_tests --test_env=RAY_CLIENT_MODE=1 python/ray/tests/...; fi
|
||||
|
||||
- os: linux
|
||||
env:
|
||||
@@ -219,7 +176,7 @@ matrix:
|
||||
- . ./ci/travis/ci.sh init RAY_CI_MACOS_WHEELS_AFFECTED,RAY_CI_JAVA_AFFECTED,RAY_CI_STREAMING_JAVA_AFFECTED
|
||||
before_script:
|
||||
- brew tap adoptopenjdk/openjdk
|
||||
- brew cask install adoptopenjdk8
|
||||
- brew install --cask adoptopenjdk8
|
||||
- export JAVA_HOME=/Library/Java/JavaVirtualMachines/adoptopenjdk-8.jdk/Contents/Home
|
||||
- java -version
|
||||
- . ./ci/travis/ci.sh build
|
||||
@@ -577,4 +534,3 @@ deploy:
|
||||
repo: ray-project/ray
|
||||
branch: master
|
||||
condition: $MULTIPLATFORM_JARS = 1 || $MAC_JARS = 1 || $LINUX_JARS = 1
|
||||
|
||||
|
||||
+3
-109
@@ -880,7 +880,7 @@ cc_test(
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "local_placement_group_manager_test",
|
||||
name = "placement_group_resource_manager_test",
|
||||
srcs = ["src/ray/raylet/placement_group_resource_manager_test.cc"],
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
@@ -956,8 +956,8 @@ cc_test(
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "task_dependency_manager_test",
|
||||
srcs = ["src/ray/raylet/task_dependency_manager_test.cc"],
|
||||
name = "dependency_manager_test",
|
||||
srcs = ["src/ray/raylet/dependency_manager_test.cc"],
|
||||
copts = COPTS,
|
||||
deps = [
|
||||
":raylet_lib",
|
||||
@@ -1020,7 +1020,6 @@ cc_test(
|
||||
cc_library(
|
||||
name = "gcs_test_util_lib",
|
||||
hdrs = [
|
||||
"src/ray/gcs/test/accessor_test_base.h",
|
||||
"src/ray/gcs/test/gcs_test_util.h",
|
||||
],
|
||||
copts = COPTS,
|
||||
@@ -1621,111 +1620,6 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
# TODO(micafan) Support test group in future. Use test group we can run all gcs test once.
|
||||
cc_test(
|
||||
name = "redis_gcs_client_test",
|
||||
srcs = ["src/ray/gcs/test/redis_gcs_client_test.cc"],
|
||||
args = [
|
||||
"$(location redis-server)",
|
||||
"$(location redis-cli)",
|
||||
"$(location libray_redis_module.so)",
|
||||
],
|
||||
copts = COPTS,
|
||||
data = [
|
||||
"//:libray_redis_module.so",
|
||||
"//:redis-cli",
|
||||
"//:redis-server",
|
||||
],
|
||||
deps = [
|
||||
":gcs",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "redis_actor_info_accessor_test",
|
||||
srcs = ["src/ray/gcs/test/redis_actor_info_accessor_test.cc"],
|
||||
args = [
|
||||
"$(location redis-server)",
|
||||
"$(location redis-cli)",
|
||||
"$(location libray_redis_module.so)",
|
||||
],
|
||||
copts = COPTS,
|
||||
data = [
|
||||
"//:libray_redis_module.so",
|
||||
"//:redis-cli",
|
||||
"//:redis-server",
|
||||
],
|
||||
deps = [
|
||||
":gcs",
|
||||
":gcs_test_util_lib",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "redis_object_info_accessor_test",
|
||||
srcs = ["src/ray/gcs/test/redis_object_info_accessor_test.cc"],
|
||||
args = [
|
||||
"$(location redis-server)",
|
||||
"$(location redis-cli)",
|
||||
"$(location libray_redis_module.so)",
|
||||
],
|
||||
copts = COPTS,
|
||||
data = [
|
||||
"//:libray_redis_module.so",
|
||||
"//:redis-cli",
|
||||
"//:redis-server",
|
||||
],
|
||||
deps = [
|
||||
":gcs",
|
||||
":gcs_test_util_lib",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "redis_job_info_accessor_test",
|
||||
srcs = ["src/ray/gcs/test/redis_job_info_accessor_test.cc"],
|
||||
args = [
|
||||
"$(location redis-server)",
|
||||
"$(location redis-cli)",
|
||||
"$(location libray_redis_module.so)",
|
||||
],
|
||||
copts = COPTS,
|
||||
data = [
|
||||
"//:libray_redis_module.so",
|
||||
"//:redis-cli",
|
||||
"//:redis-server",
|
||||
],
|
||||
deps = [
|
||||
":gcs",
|
||||
":gcs_test_util_lib",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "redis_node_info_accessor_test",
|
||||
srcs = ["src/ray/gcs/test/redis_node_info_accessor_test.cc"],
|
||||
args = [
|
||||
"$(location redis-server)",
|
||||
"$(location redis-cli)",
|
||||
"$(location libray_redis_module.so)",
|
||||
],
|
||||
copts = COPTS,
|
||||
data = [
|
||||
"//:libray_redis_module.so",
|
||||
"//:redis-cli",
|
||||
"//:redis-server",
|
||||
],
|
||||
deps = [
|
||||
":gcs",
|
||||
":gcs_test_util_lib",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "asio_test",
|
||||
srcs = ["src/ray/gcs/test/asio_test.cc"],
|
||||
|
||||
+4
-2
@@ -1,12 +1,14 @@
|
||||
# py_test_module_list creates a py_test target for each
|
||||
# Python file in `files`
|
||||
def py_test_module_list(files, size, deps, extra_srcs, **kwargs):
|
||||
def py_test_module_list(files, size, deps, extra_srcs, name_suffix="", **kwargs):
|
||||
for file in files:
|
||||
# remove .py
|
||||
name = file[:-3]
|
||||
name = file[:-3] + name_suffix
|
||||
main = file
|
||||
native.py_test(
|
||||
name = name,
|
||||
size = size,
|
||||
main = file,
|
||||
srcs = extra_srcs + [file],
|
||||
**kwargs
|
||||
)
|
||||
|
||||
+9
-2
@@ -120,7 +120,6 @@ test_core() {
|
||||
case "${OSTYPE}" in
|
||||
msys)
|
||||
args+=(
|
||||
-//:redis_gcs_client_test
|
||||
-//:core_worker_test
|
||||
-//:event_test
|
||||
-//:gcs_pub_sub_test
|
||||
@@ -262,6 +261,11 @@ _bazel_build_before_install() {
|
||||
bazel build "${target}"
|
||||
}
|
||||
|
||||
|
||||
# Build only the //:install_py_proto Bazel target.
_bazel_build_protobuf() {
  bazel build "//:install_py_proto"
}
|
||||
|
||||
install_ray() {
|
||||
# TODO(mehrdadn): This function should be unified with the one in python/build-wheel-windows.sh.
|
||||
(
|
||||
@@ -296,7 +300,8 @@ build_wheels() {
|
||||
;;
|
||||
darwin*)
|
||||
# This command should be kept in sync with ray/python/README-building-wheels.md.
|
||||
suppress_output "${WORKSPACE_DIR}"/python/build-wheel-macos.sh
|
||||
# Remove suppress_output for now to avoid timeout
|
||||
"${WORKSPACE_DIR}"/python/build-wheel-macos.sh
|
||||
;;
|
||||
msys*)
|
||||
keep_alive "${WORKSPACE_DIR}"/python/build-wheel-windows.sh
|
||||
@@ -457,6 +462,8 @@ init() {
|
||||
build() {
|
||||
if [ "${LINT-}" != 1 ]; then
|
||||
_bazel_build_before_install
|
||||
else
|
||||
_bazel_build_protobuf
|
||||
fi
|
||||
|
||||
if ! need_wheels; then
|
||||
|
||||
@@ -16,6 +16,7 @@ template <typename T>
|
||||
class ObjectRef {
|
||||
public:
|
||||
ObjectRef();
|
||||
~ObjectRef();
|
||||
|
||||
ObjectRef(const ObjectID &id);
|
||||
|
||||
@@ -46,6 +47,17 @@ ObjectRef<T>::ObjectRef() {}
|
||||
template <typename T>
|
||||
ObjectRef<T>::ObjectRef(const ObjectID &id) {
|
||||
id_ = id;
|
||||
if (CoreWorkerProcess::IsInitialized()) {
|
||||
auto &core_worker = CoreWorkerProcess::GetCoreWorker();
|
||||
core_worker.AddLocalReference(id_);
|
||||
}
|
||||
}
|
||||
template <typename T>
|
||||
ObjectRef<T>::~ObjectRef() {
|
||||
if (CoreWorkerProcess::IsInitialized()) {
|
||||
auto &core_worker = CoreWorkerProcess::GetCoreWorker();
|
||||
core_worker.RemoveLocalReference(id_);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
||||
@@ -91,16 +91,13 @@ TEST(RayClusterModeTest, FullTest) {
|
||||
auto r5 = Ray::Task(Plus, r4, r3).Remote();
|
||||
auto r6 = Ray::Task(Plus, r4, 10).Remote();
|
||||
|
||||
///// TODO(ameer/guyang): All the commented code lines below should be
|
||||
///// uncommented once reference counting is added. Currently the objects
|
||||
///// are leaking from the object store.
|
||||
int result5 = *(Ray::Get(r5));
|
||||
// int result4 = *(Ray::Get(r4));
|
||||
int result4 = *(Ray::Get(r4));
|
||||
int result6 = *(Ray::Get(r6));
|
||||
// int result3 = *(Ray::Get(r3));
|
||||
int result3 = *(Ray::Get(r3));
|
||||
EXPECT_EQ(result0, 1);
|
||||
// EXPECT_EQ(result3, 1);
|
||||
// EXPECT_EQ(result4, 2);
|
||||
EXPECT_EQ(result3, 1);
|
||||
EXPECT_EQ(result4, 2);
|
||||
EXPECT_EQ(result5, 3);
|
||||
EXPECT_EQ(result6, 12);
|
||||
|
||||
@@ -114,37 +111,34 @@ TEST(RayClusterModeTest, FullTest) {
|
||||
int result7 = *(Ray::Get(r7));
|
||||
int result8 = *(Ray::Get(r8));
|
||||
int result9 = *(Ray::Get(r9));
|
||||
// int result10 = *(Ray::Get(r10));
|
||||
int result10 = *(Ray::Get(r10));
|
||||
EXPECT_EQ(result7, 15);
|
||||
EXPECT_EQ(result8, 16);
|
||||
EXPECT_EQ(result9, 19);
|
||||
// EXPECT_EQ(result10, 27);
|
||||
EXPECT_EQ(result10, 27);
|
||||
|
||||
/// create actor and task function remote call with args passed by reference
|
||||
// ActorHandle<Counter> actor5 = Ray::Actor(Counter::FactoryCreate, r10, 0).Remote();
|
||||
ActorHandle<Counter> actor5 = Ray::Actor(Counter::FactoryCreate, 27, 0).Remote();
|
||||
// auto r11 = actor5.Task(&Counter::Add, r0).Remote();
|
||||
auto r11 = actor5.Task(&Counter::Add, 1).Remote();
|
||||
// auto r12 = actor5.Task(&Counter::Add, r11).Remote();
|
||||
ActorHandle<Counter> actor5 = Ray::Actor(Counter::FactoryCreate, r10, 0).Remote();
|
||||
|
||||
auto r11 = actor5.Task(&Counter::Add, r0).Remote();
|
||||
auto r12 = actor5.Task(&Counter::Add, r11).Remote();
|
||||
auto r13 = actor5.Task(&Counter::Add, r10).Remote();
|
||||
auto r14 = actor5.Task(&Counter::Add, r13).Remote();
|
||||
// auto r15 = Ray::Task(Plus, r0, r11).Remote();
|
||||
auto r15 = Ray::Task(Plus, 1, r11).Remote();
|
||||
auto r15 = Ray::Task(Plus, r0, r11).Remote();
|
||||
auto r16 = Ray::Task(Plus1, r15).Remote();
|
||||
|
||||
// int result12 = *(Ray::Get(r12));
|
||||
int result12 = *(Ray::Get(r12));
|
||||
int result14 = *(Ray::Get(r14));
|
||||
// int result11 = *(Ray::Get(r11));
|
||||
// int result13 = *(Ray::Get(r13));
|
||||
int result11 = *(Ray::Get(r11));
|
||||
int result13 = *(Ray::Get(r13));
|
||||
int result16 = *(Ray::Get(r16));
|
||||
// int result15 = *(Ray::Get(r15));
|
||||
int result15 = *(Ray::Get(r15));
|
||||
|
||||
// EXPECT_EQ(result11, 28);
|
||||
// EXPECT_EQ(result12, 56);
|
||||
// EXPECT_EQ(result13, 83);
|
||||
// EXPECT_EQ(result14, 166);
|
||||
EXPECT_EQ(result14, 110);
|
||||
// EXPECT_EQ(result15, 29);
|
||||
EXPECT_EQ(result11, 28);
|
||||
EXPECT_EQ(result12, 56);
|
||||
EXPECT_EQ(result13, 83);
|
||||
EXPECT_EQ(result14, 166);
|
||||
EXPECT_EQ(result15, 29);
|
||||
EXPECT_EQ(result16, 30);
|
||||
|
||||
Ray::Shutdown();
|
||||
|
||||
@@ -77,7 +77,8 @@ class DataOrganizer:
|
||||
job_workers = {}
|
||||
node_workers = {}
|
||||
core_worker_stats = {}
|
||||
for node_id in DataSource.nodes.keys():
|
||||
# await inside for loop, so we create a copy of keys().
|
||||
for node_id in list(DataSource.nodes.keys()):
|
||||
workers = await cls.get_node_workers(node_id)
|
||||
for worker in workers:
|
||||
job_id = worker["jobId"]
|
||||
|
||||
@@ -4,7 +4,7 @@ import ray.utils
|
||||
import ray.new_dashboard.utils as dashboard_utils
|
||||
import ray.new_dashboard.actor_utils as actor_utils
|
||||
from ray.new_dashboard.utils import rest_response
|
||||
from ray.new_dashboard.datacenter import DataOrganizer
|
||||
from ray.new_dashboard.datacenter import DataOrganizer, DataSource
|
||||
from ray.core.generated import core_worker_pb2
|
||||
from ray.core.generated import core_worker_pb2_grpc
|
||||
|
||||
@@ -29,6 +29,14 @@ class LogicalViewHead(dashboard_utils.DashboardHeadModule):
|
||||
message="Fetched actor groups.",
|
||||
actor_groups=actor_groups)
|
||||
|
||||
@routes.get("/logical/actors")
@dashboard_utils.aiohttp_cache
async def get_all_actors(self, req) -> aiohttp.web.Response:
    """Handle GET /logical/actors.

    Responds with a successful REST payload whose ``actors`` field is the
    current content of ``DataSource.actors``. ``req`` is not inspected.
    """
    payload = dict(
        success=True,
        message="All actors fetched.",
        actors=DataSource.actors,
    )
    return dashboard_utils.rest_response(**payload)
|
||||
|
||||
@routes.get("/logical/kill_actor")
|
||||
async def kill_actor(self, req) -> aiohttp.web.Response:
|
||||
try:
|
||||
|
||||
@@ -35,7 +35,7 @@ def test_actor_groups(ray_start_with_dashboard):
|
||||
assert wait_until_server_available(webui_url)
|
||||
webui_url = format_web_url(webui_url)
|
||||
|
||||
timeout_seconds = 5
|
||||
timeout_seconds = 10
|
||||
start_time = time.time()
|
||||
last_ex = None
|
||||
while True:
|
||||
@@ -79,6 +79,63 @@ def test_actor_groups(ray_start_with_dashboard):
|
||||
raise Exception(f"Timed out while testing, {ex_stack}")
|
||||
|
||||
|
||||
def test_actors(disable_aiohttp_cache, ray_start_with_dashboard):
    """Check that ``/logical/actors`` reports every created actor.

    Two ``Foo`` actors and one actor requesting a GPU are created, then the
    dashboard endpoint is polled once per second until the response contains
    three actor entries with the expected fields, or until the timeout
    expires.

    Raises:
        Exception: if no poll succeeded before the deadline. The message
            carries the traceback of the last failed attempt.
    """

    @ray.remote
    class Foo:
        def __init__(self, num):
            self.num = num

        def do_task(self):
            return self.num

    @ray.remote(num_gpus=1)
    class InfeasibleActor:
        pass

    foo_actors = [Foo.remote(4), Foo.remote(5)]
    # The handles below are intentionally unused; they are only bound to
    # names (presumably so the actors/tasks stay referenced during the test).
    infeasible_actor = InfeasibleActor.remote()  # noqa
    results = [actor.do_task.remote() for actor in foo_actors]  # noqa
    webui_url = ray_start_with_dashboard["webui_url"]
    assert wait_until_server_available(webui_url)
    webui_url = format_web_url(webui_url)

    timeout_seconds = 5
    deadline = time.time() + timeout_seconds
    while True:
        time.sleep(1)
        try:
            resp = requests.get(f"{webui_url}/logical/actors")
            resp_json = resp.json()
            resp_data = resp_json["data"]
            actors = resp_data["actors"]
            assert len(actors) == 3
            one_entry = list(actors.values())[0]
            assert "jobId" in one_entry
            assert "taskSpec" in one_entry
            assert "functionDescriptor" in one_entry["taskSpec"]
            assert isinstance(one_entry["taskSpec"]["functionDescriptor"],
                              dict)
            assert "address" in one_entry
            assert isinstance(one_entry["address"], dict)
            assert "state" in one_entry
            assert "name" in one_entry
            assert "numRestarts" in one_entry
            assert "pid" in one_entry
            all_pids = [entry["pid"] for entry in actors.values()]
            assert 0 in all_pids  # The infeasible actor
            assert len(all_pids) > 1
            break
        except Exception as ex:
            # Only a *failed* attempt can time out. Checking the deadline in
            # a ``finally`` clause would also override the ``break`` of a
            # successful attempt and report a stale error.
            if time.time() > deadline:
                ex_stack = "".join(
                    traceback.format_exception(
                        type(ex), ex, ex.__traceback__))
                raise Exception(
                    f"Timed out while testing, {ex_stack}") from ex
|
||||
|
||||
|
||||
def test_kill_actor(ray_start_with_dashboard):
|
||||
@ray.remote
|
||||
class Actor:
|
||||
|
||||
@@ -13,6 +13,7 @@ import ray.new_dashboard.utils as dashboard_utils
|
||||
import ray._private.services
|
||||
import ray.utils
|
||||
from ray.autoscaler._private.util import (DEBUG_AUTOSCALING_STATUS,
|
||||
DEBUG_AUTOSCALING_STATUS_LEGACY,
|
||||
DEBUG_AUTOSCALING_ERROR)
|
||||
from ray.core.generated import reporter_pb2
|
||||
from ray.core.generated import reporter_pb2_grpc
|
||||
@@ -113,13 +114,20 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
|
||||
"""
|
||||
|
||||
aioredis_client = self._dashboard_head.aioredis_client
|
||||
status = await aioredis_client.hget(DEBUG_AUTOSCALING_STATUS, "value")
|
||||
legacy_status = await aioredis_client.hget(
|
||||
DEBUG_AUTOSCALING_STATUS_LEGACY, "value")
|
||||
formatted_status_string = await aioredis_client.hget(
|
||||
DEBUG_AUTOSCALING_STATUS, "value")
|
||||
formatted_status = json.loads(formatted_status_string.decode()
|
||||
) if formatted_status_string else {}
|
||||
error = await aioredis_client.hget(DEBUG_AUTOSCALING_ERROR, "value")
|
||||
return dashboard_utils.rest_response(
|
||||
success=True,
|
||||
message="Got cluster status.",
|
||||
autoscaling_status=status.decode() if status else None,
|
||||
autoscaling_status=legacy_status.decode()
|
||||
if legacy_status else None,
|
||||
autoscaling_error=error.decode() if error else None,
|
||||
cluster_status=formatted_status if formatted_status else None,
|
||||
)
|
||||
|
||||
async def run(self, server):
|
||||
|
||||
@@ -246,7 +246,8 @@ class StatsCollector(dashboard_utils.DashboardHeadModule):
|
||||
@async_loop_forever(
|
||||
stats_collector_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS)
|
||||
async def _update_node_stats(self):
|
||||
for node_id, stub in self._stubs.items():
|
||||
# Copy self._stubs to avoid `dictionary changed size during iteration`.
|
||||
for node_id, stub in list(self._stubs.items()):
|
||||
node_info = DataSource.nodes.get(node_id)
|
||||
if node_info["state"] != "ALIVE":
|
||||
continue
|
||||
|
||||
@@ -112,20 +112,16 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||
def check_mem_table():
|
||||
resp = requests.get(f"{webui_url}/memory/memory_table")
|
||||
resp_data = resp.json()
|
||||
if not resp_data["result"]:
|
||||
return False
|
||||
assert resp_data["result"]
|
||||
latest_memory_table = resp_data["data"]["memoryTable"]
|
||||
summary = latest_memory_table["summary"]
|
||||
try:
|
||||
# 1 ref per handle and per object the actor has a ref to
|
||||
assert summary["totalActorHandles"] == len(actors) * 2
|
||||
# 1 ref for my_obj
|
||||
assert summary["totalLocalRefCount"] == 1
|
||||
return True
|
||||
except AssertionError:
|
||||
return False
|
||||
# 1 ref per handle and per object the actor has a ref to
|
||||
assert summary["totalActorHandles"] == len(actors) * 2
|
||||
# 1 ref for my_obj
|
||||
assert summary["totalLocalRefCount"] == 1
|
||||
|
||||
wait_for_condition(check_mem_table, 10)
|
||||
wait_until_succeeded_without_exception(
|
||||
check_mem_table, (AssertionError, ), timeout_ms=1000)
|
||||
|
||||
|
||||
def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard):
|
||||
|
||||
@@ -19,7 +19,7 @@ from ray import ray_constants
|
||||
from ray.test_utils import (format_web_url, wait_for_condition,
|
||||
wait_until_server_available, run_string_as_driver,
|
||||
wait_until_succeeded_without_exception)
|
||||
from ray.autoscaler._private.util import (DEBUG_AUTOSCALING_STATUS,
|
||||
from ray.autoscaler._private.util import (DEBUG_AUTOSCALING_STATUS_LEGACY,
|
||||
DEBUG_AUTOSCALING_ERROR)
|
||||
import ray.new_dashboard.consts as dashboard_consts
|
||||
import ray.new_dashboard.utils as dashboard_utils
|
||||
@@ -458,11 +458,14 @@ def test_get_cluster_status(ray_start_with_dashboard):
|
||||
def get_cluster_status():
|
||||
response = requests.get(f"{webui_url}/api/cluster_status")
|
||||
response.raise_for_status()
|
||||
print(response.json())
|
||||
assert response.json()["result"]
|
||||
assert "autoscalingStatus" in response.json()["data"]
|
||||
assert response.json()["data"]["autoscalingStatus"] is None
|
||||
assert "autoscalingError" in response.json()["data"]
|
||||
assert response.json()["data"]["autoscalingError"] is None
|
||||
assert "clusterStatus" in response.json()["data"]
|
||||
assert "loadMetricsReport" in response.json()["data"]["clusterStatus"]
|
||||
|
||||
wait_until_succeeded_without_exception(get_cluster_status,
|
||||
(requests.RequestException, ))
|
||||
@@ -478,7 +481,7 @@ def test_get_cluster_status(ray_start_with_dashboard):
|
||||
port=int(address[1]),
|
||||
password=ray_constants.REDIS_DEFAULT_PASSWORD)
|
||||
|
||||
client.hset(DEBUG_AUTOSCALING_STATUS, "value", "hello")
|
||||
client.hset(DEBUG_AUTOSCALING_STATUS_LEGACY, "value", "hello")
|
||||
client.hset(DEBUG_AUTOSCALING_ERROR, "value", "world")
|
||||
|
||||
response = requests.get(f"{webui_url}/api/cluster_status")
|
||||
@@ -488,6 +491,8 @@ def test_get_cluster_status(ray_start_with_dashboard):
|
||||
assert response.json()["data"]["autoscalingStatus"] == "hello"
|
||||
assert "autoscalingError" in response.json()["data"]
|
||||
assert response.json()["data"]["autoscalingError"] == "world"
|
||||
assert "clusterStatus" in response.json()["data"]
|
||||
assert "loadMetricsReport" in response.json()["data"]["clusterStatus"]
|
||||
|
||||
|
||||
def test_immutable_types():
|
||||
|
||||
@@ -7,8 +7,9 @@ from ray.new_dashboard.memory_utils import (
|
||||
NODE_ADDRESS = "127.0.0.1"
|
||||
IS_DRIVER = True
|
||||
PID = 1
|
||||
OBJECT_ID = "7wpsIhgZiBz/////AQAAyAEAAAA="
|
||||
ACTOR_ID = "fffffffffffffffff66d17ba010000c801000000"
|
||||
|
||||
OBJECT_ID = "ZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZg=="
|
||||
ACTOR_ID = "fffffffffffffffffffffffffffffffff66d17ba010000c801000000"
|
||||
DECODED_ID = decode_object_ref_if_needed(OBJECT_ID)
|
||||
OBJECT_SIZE = 100
|
||||
|
||||
|
||||
@@ -392,3 +392,39 @@ To get information about the current available resource capacity of your cluster
|
||||
|
||||
.. autofunction:: ray.available_resources
|
||||
:noindex:
|
||||
|
||||
Object Spilling
|
||||
---------------
|
||||
|
||||
Ray 1.2.0+ has *beta* support for spilling objects to external storage once the capacity
|
||||
of the object store is used up. Please file a `GitHub issue <https://github.com/ray-project/ray/issues/>`__
|
||||
if you encounter any problems with this new feature. Eventually, object spilling will be
|
||||
enabled by default, but for now you need to enable it manually:
|
||||
|
||||
To enable object spilling to the local filesystem (single node clusters only):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
ray.init(
|
||||
_system_config={
|
||||
"automatic_object_spilling_enabled": True,
|
||||
"object_spilling_config": json.dumps(
|
||||
{"type": "filesystem", "params": {"directory_path": "/tmp/spill"}},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
To enable object spilling to remote storage (any URI supported by `smart_open <https://pypi.org/project/smart-open/>`__):
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
ray.init(
|
||||
_system_config={
|
||||
"automatic_object_spilling_enabled": True,
|
||||
"max_io_workers": 4, # More IO workers for remote storage.
|
||||
"min_spilling_size": 100 * 1024 * 1024, # Spill at least 100MB at a time.
|
||||
"object_spilling_config": json.dumps(
|
||||
{"type": "smart_open", "params": {"uri": "s3:///bucket/path"}},
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@@ -41,6 +41,7 @@ MOCK_MODULES = [
|
||||
"horovod",
|
||||
"horovod.ray",
|
||||
"kubernetes",
|
||||
"mlflow",
|
||||
"mxnet",
|
||||
"mxnet.model",
|
||||
"psutil",
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 3.8 KiB |
@@ -89,6 +89,112 @@ The Ray debugger supports the
|
||||
`same commands as PDB
|
||||
<https://docs.python.org/3/library/pdb.html#debugger-commands>`_.
|
||||
|
||||
Stepping between Ray tasks
|
||||
--------------------------
|
||||
|
||||
You can use the debugger to step between Ray tasks. Let's take the
|
||||
following recursive function as an example:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
|
||||
@ray.remote
|
||||
def fact(n):
|
||||
if n == 1:
|
||||
return n
|
||||
else:
|
||||
n_id = fact.remote(n - 1)
|
||||
return n * ray.get(n_id)
|
||||
|
||||
ray.util.pdb.set_trace()
|
||||
result_ref = fact.remote(5)
|
||||
result = ray.get(result_ref)
|
||||
|
||||
|
||||
After running the program by executing the Python file and calling
|
||||
``ray debug``, you can select the breakpoint by pressing ``0`` and
|
||||
enter. This will result in the following output:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
Enter breakpoint index or press enter to refresh: 0
|
||||
> /Users/pcmoritz/tmp/stepping.py(14)<module>()
|
||||
-> result_ref = fact.remote(5)
|
||||
(Pdb)
|
||||
|
||||
You can jump into the call with the ``remote`` command in Ray's debugger.
|
||||
Inside the function, print the value of `n` with ``p(n)``, resulting in
|
||||
the following output:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
-> result_ref = fact.remote(5)
|
||||
(Pdb) remote
|
||||
*** Connection closed by remote host ***
|
||||
Continuing pdb session in different process...
|
||||
--Call--
|
||||
> /Users/pcmoritz/tmp/stepping.py(5)fact()
|
||||
-> @ray.remote
|
||||
(Pdb) ll
|
||||
5 -> @ray.remote
|
||||
6 def fact(n):
|
||||
7 if n == 1:
|
||||
8 return n
|
||||
9 else:
|
||||
10 n_id = fact.remote(n - 1)
|
||||
11 return n * ray.get(n_id)
|
||||
(Pdb) p(n)
|
||||
5
|
||||
(Pdb)
|
||||
|
||||
Now step into the next remote call again with
|
||||
``remote`` and print `n`. You can now either continue recursing into
|
||||
the function by calling ``remote`` a few more times, or you can jump
|
||||
to the location where ``ray.get`` is called on the result by using the
|
||||
``get`` debugger command. Use ``get`` again to jump back to the original
|
||||
call site and use ``p(result)`` to print the result:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
Enter breakpoint index or press enter to refresh: 0
|
||||
> /Users/pcmoritz/tmp/stepping.py(14)<module>()
|
||||
-> result_ref = fact.remote(5)
|
||||
(Pdb) remote
|
||||
*** Connection closed by remote host ***
|
||||
Continuing pdb session in different process...
|
||||
--Call--
|
||||
> /Users/pcmoritz/tmp/stepping.py(5)fact()
|
||||
-> @ray.remote
|
||||
(Pdb) p(n)
|
||||
5
|
||||
(Pdb) remote
|
||||
*** Connection closed by remote host ***
|
||||
Continuing pdb session in different process...
|
||||
--Call--
|
||||
> /Users/pcmoritz/tmp/stepping.py(5)fact()
|
||||
-> @ray.remote
|
||||
(Pdb) p(n)
|
||||
4
|
||||
(Pdb) get
|
||||
*** Connection closed by remote host ***
|
||||
Continuing pdb session in different process...
|
||||
--Return--
|
||||
> /Users/pcmoritz/tmp/stepping.py(5)fact()->120
|
||||
-> @ray.remote
|
||||
(Pdb) get
|
||||
*** Connection closed by remote host ***
|
||||
Continuing pdb session in different process...
|
||||
--Return--
|
||||
> /Users/pcmoritz/tmp/stepping.py(14)<module>()->None
|
||||
-> result_ref = fact.remote(5)
|
||||
(Pdb) p(result)
|
||||
120
|
||||
(Pdb)
|
||||
|
||||
|
||||
Post Mortem Debugging
|
||||
---------------------
|
||||
|
||||
|
||||
@@ -267,6 +267,26 @@ That's it. Let's take a look at an example:
|
||||
|
||||
.. literalinclude:: ../../../python/ray/serve/examples/doc/snippet_model_composition.py
|
||||
|
||||
|
||||
.. _serve-sync-async-handles:
|
||||
|
||||
Sync and Async Handles
|
||||
======================
|
||||
|
||||
Ray Serve offers two types of ``ServeHandle``. You can use the ``client.get_handle(..., sync=True|False)``
|
||||
flag to toggle between them.
|
||||
|
||||
- When you set ``sync=True`` (the default), a synchronous handle is returned.
|
||||
Calling ``handle.remote()`` should return a Ray ObjectRef.
|
||||
- When you set ``sync=False``, an asyncio based handle is returned. You need to
|
||||
call it with ``await handle.remote()`` to return a Ray ObjectRef. To use ``await``,
|
||||
you have to run ``client.get_handle`` and ``handle.remote`` in a Python asyncio event loop.
|
||||
|
||||
The async handle has a performance advantage because it uses asyncio directly, as compared
|
||||
to the sync handle, which talks to an asyncio event loop in a thread. To learn more about
|
||||
the reasoning behind these, check out our `architecture documentation <./architecture.html>`_.
|
||||
|
||||
|
||||
Monitoring
|
||||
==========
|
||||
|
||||
@@ -327,3 +347,11 @@ as shown below.
|
||||
:mod:`client.create_backend <ray.serve.api.Client.create_backend>` by
|
||||
default.
|
||||
|
||||
The dependencies required in the backend may be different than
|
||||
the dependencies installed in the driver program (the one running Serve API
|
||||
calls). In this case, you can use an
|
||||
:mod:`ImportedBackend <ray.serve.backends.ImportedBackend>` to specify a
|
||||
backend based on a class that is installed in the Python environment that
|
||||
the workers will run in. Example:
|
||||
|
||||
.. literalinclude:: ../../../python/ray/serve/examples/doc/imported_backend.py
|
||||
|
||||
+15
-22
@@ -117,7 +117,7 @@ policies <serve-split-traffic>`, finding the next available replica, and
|
||||
batching requests together.
|
||||
|
||||
When the request arrives in the model, you can access the data similarly to how
|
||||
you would with HTTP request. Here are some examples how ServeRequest mirrors Flask.Request:
|
||||
you would with an HTTP request. Here are some examples of how ServeRequest mirrors Starlette.Request:
|
||||
|
||||
.. list-table::
|
||||
:header-rows: 1
|
||||
@@ -125,25 +125,25 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla
|
||||
* - HTTP
|
||||
- ServeHandle
|
||||
- | Request
|
||||
| (Flask.Request and ServeRequest)
|
||||
| (Starlette.Request and ServeRequest)
|
||||
* - ``requests.get(..., headers={...})``
|
||||
- ``handle.options(http_headers={...})``
|
||||
- ``request.headers``
|
||||
* - ``requests.post(...)``
|
||||
- ``handle.options(http_method="POST")``
|
||||
- ``requests.method``
|
||||
* - ``request.get(..., json={...})``
|
||||
- ``request.method``
|
||||
* - ``requests.get(..., json={...})``
|
||||
- ``handle.remote({...})``
|
||||
- ``request.json``
|
||||
* - ``request.get(..., form={...})``
|
||||
- ``await request.json()``
|
||||
* - ``requests.get(..., form={...})``
|
||||
- ``handle.remote({...})``
|
||||
- ``request.form``
|
||||
* - ``request.get(..., params={"a":"b"})``
|
||||
- ``await request.form()``
|
||||
* - ``requests.get(..., params={"a":"b"})``
|
||||
- ``handle.remote(a="b")``
|
||||
- ``request.args``
|
||||
* - ``request.get(..., data="long string")``
|
||||
- ``request.query_params``
|
||||
* - ``requests.get(..., data="long string")``
|
||||
- ``handle.remote("long string")``
|
||||
- ``request.data``
|
||||
- ``await request.body()``
|
||||
* - ``N/A``
|
||||
- ``handle.remote(python_object)``
|
||||
- ``request.data``
|
||||
@@ -157,9 +157,9 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
import flask
|
||||
import starlette.requests
|
||||
|
||||
if isinstance(request, flask.Request):
|
||||
if isinstance(request, starlette.requests.Request):
|
||||
print("Request coming from web!")
|
||||
elif isinstance(request, ServeRequest):
|
||||
print("Request coming from Python!")
|
||||
@@ -170,10 +170,10 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
handle.remote(flask_request)
|
||||
handle.remote(starlette_request)
|
||||
|
||||
In this case, Serve will `not` wrap it in ServeRequest. You can directly
|
||||
process the request as a ``flask.Request``.
|
||||
process the request as a ``starlette.requests.Request``.
|
||||
|
||||
How fast is Ray Serve?
|
||||
----------------------
|
||||
@@ -187,13 +187,6 @@ You can checkout our `microbenchmark instruction <https://github.com/ray-project
|
||||
to benchmark on your hardware.
|
||||
|
||||
|
||||
Does Ray Serve use Flask?
|
||||
-------------------------
|
||||
Flask is only used as a web request object for servable to consume the data.
|
||||
We actually use the fastest Python web server: `Uvicorn <https://www.uvicorn.org/>`_ as our web server,
|
||||
alongside with the power of Python asyncio.
|
||||
**Flask is ONLY the request object that we are using, Uvicorn (not flask) provides the webserver.**
|
||||
|
||||
Can I use asyncio along with Ray Serve?
|
||||
---------------------------------------
|
||||
Yes! You can make your servable methods ``async def`` and Serve will run them
|
||||
|
||||
@@ -33,6 +33,9 @@ Since Serve is built on Ray, it also allows you to scale to many machines, in yo
|
||||
If you want to try out Serve, join our `community slack <https://forms.gle/9TSdDYUgxYs8SA9e8>`_
|
||||
and discuss in the #serve channel.
|
||||
|
||||
.. note::
|
||||
Starting with Ray version 1.3.0, Ray Serve backends must take in a Starlette Request object instead of a Flask Request object.
|
||||
See the `migration guide <https://docs.google.com/document/d/1CG4y5WTTc4G_MRQGyjnb_eZ7GK3G9dUX6TNLKLnKRAc/edit?usp=sharing>`_ for details.
|
||||
|
||||
Installation
|
||||
============
|
||||
|
||||
@@ -19,10 +19,8 @@ Backends
|
||||
Backends define the implementation of your business logic or models that will handle requests when queries come in to :ref:`serve-endpoint`.
|
||||
In order to support seamless scalability, backends can have many replicas, which are individual processes running in the Ray cluster to handle requests.
|
||||
To define a backend, first you must define the "handler" or the business logic you'd like to respond with.
|
||||
The handler should take as input a `Flask Request object <https://flask.palletsprojects.com/en/1.1.x/api/?highlight=request#flask.Request>`_.
|
||||
The handler should return any JSON-serializable object as output. For a more customizable response type, the handler may return a
|
||||
The handler should take as input a `Starlette Request object <https://www.starlette.io/requests/>`_ and return any JSON-serializable object as output. For a more customizable response type, the handler may return a
|
||||
`Starlette Response object <https://www.starlette.io/responses/>`_.
|
||||
In the future, Ray Serve will support `Starlette Request objects <https://www.starlette.io/requests/>`_ as input as well.
|
||||
|
||||
A backend is defined using :mod:`client.create_backend <ray.serve.api.Client.create_backend>`, and the implementation can be defined as either a function or a class.
|
||||
Use a function when your response is stateless and a class when you might need to maintain some state (like a model).
|
||||
@@ -32,7 +30,7 @@ A backend consists of a number of *replicas*, which are individual copies of the
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def handle_request(flask_request):
|
||||
def handle_request(starlette_request):
|
||||
return "hello world"
|
||||
|
||||
class RequestHandler:
|
||||
@@ -40,7 +38,7 @@ A backend consists of a number of *replicas*, which are individual copies of the
|
||||
def __init__(self, msg):
|
||||
self.msg = msg
|
||||
|
||||
def __call__(self, flask_request):
|
||||
def __call__(self, starlette_request):
|
||||
return self.msg
|
||||
|
||||
client.create_backend("simple_backend", handle_request)
|
||||
|
||||
@@ -23,7 +23,7 @@ Handle API
|
||||
:members: remote, options
|
||||
|
||||
When calling from Python, the backend implementation will receive ``ServeRequest``
|
||||
objects instead of Flask requests.
|
||||
objects instead of Starlette requests.
|
||||
|
||||
.. autoclass:: ray.serve.utils.ServeRequest
|
||||
:members:
|
||||
@@ -31,3 +31,6 @@ objects instead of Flask requests.
|
||||
Batching Requests
|
||||
-----------------
|
||||
.. autofunction:: ray.serve.accept_batch
|
||||
|
||||
Built-in Backends
|
||||
.. autoclass:: ray.serve.backends.ImportedBackend
|
||||
|
||||
@@ -30,13 +30,13 @@ You can use the ``@serve.accept_batch`` decorator to annotate a function or a cl
|
||||
This annotation is needed because batched backends have different APIs compared
|
||||
to single request backends. In a batched backend, the inputs are a list of values.
|
||||
|
||||
For single query backend, the input type is a single Flask request or
|
||||
For a single-query backend, the input type is a single Starlette request or
|
||||
:mod:`ServeRequest <ray.serve.utils.ServeRequest>`:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
def single_request(
|
||||
request: Union[Flask.Request, ServeRequest],
|
||||
request: Union[starlette.requests.Request, ServeRequest],
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -47,7 +47,7 @@ types:
|
||||
|
||||
@serve.accept_batch
|
||||
def batched_request(
|
||||
request: List[Union[Flask.Request, ServeRequest]],
|
||||
request: List[Union[starlette.requests.Request, ServeRequest]],
|
||||
):
|
||||
pass
|
||||
|
||||
@@ -84,8 +84,8 @@ Ray Serve was able to evaluate them in batches.
|
||||
|
||||
What if you want to evaluate a whole batch in Python? Ray Serve allows you to send
|
||||
queries via the Python API. A batch of queries can either come from the web server
|
||||
or the Python API. Requests coming from the Python API will have the similar API
|
||||
as Flask.Request. See more on the API :ref:`here<serve-handle-explainer>`.
|
||||
or the Python API. Requests coming from the Python API will have a similar API
|
||||
to Starlette Request. See more on the API :ref:`here<serve-handle-explainer>`.
|
||||
|
||||
.. literalinclude:: ../../../../python/ray/serve/examples/doc/tutorial_batch.py
|
||||
:start-after: __doc_define_servable_v1_begin__
|
||||
|
||||
@@ -70,6 +70,11 @@ Take a look at any of the below tutorials to get started with Tune.
|
||||
:figure: /images/wandb_logo.png
|
||||
:description: :doc:`Track your experiment process with the Weights & Biases tools <tune-wandb>`
|
||||
|
||||
.. customgalleryitem::
|
||||
:tooltip: Use MLFlow with Ray Tune.
|
||||
:figure: /images/mlflow.png
|
||||
:description: :doc:`Log and track your hyperparameter sweep with MLFlow Tracking & AutoLogging <tune-mlflow>`
|
||||
|
||||
|
||||
.. raw:: html
|
||||
|
||||
@@ -81,12 +86,13 @@ Take a look at any of the below tutorials to get started with Tune.
|
||||
|
||||
tune-tutorial.rst
|
||||
tune-advanced-tutorial.rst
|
||||
tune-lifecycle.rst
|
||||
tune-distributed.rst
|
||||
tune-sklearn.rst
|
||||
tune-lifecycle.rst
|
||||
tune-mlflow.rst
|
||||
tune-pytorch-cifar.rst
|
||||
tune-pytorch-lightning.rst
|
||||
tune-serve-integration-mnist.rst
|
||||
tune-sklearn.rst
|
||||
tune-xgboost.rst
|
||||
tune-wandb.rst
|
||||
|
||||
@@ -156,4 +162,4 @@ Check out:
|
||||
|
||||
.. _tune-faq:
|
||||
|
||||
.. include:: _faq.rst
|
||||
.. include:: _faq.rst
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
.. _tune-mlflow:
|
||||
|
||||
Using MLFlow with Tune
|
||||
======================
|
||||
|
||||
`MLFlow <https://mlflow.org/>`_ is an open source platform to manage the ML lifecycle, including experimentation,
|
||||
reproducibility, deployment, and a central model registry. It currently offers four components, including
|
||||
MLFlow Tracking to record and query experiments, including code, data, config, and results.
|
||||
|
||||
.. image:: /images/mlflow.png
|
||||
:height: 80px
|
||||
:alt: MLflow
|
||||
:align: center
|
||||
:target: https://www.mlflow.org/
|
||||
|
||||
Ray Tune currently offers two lightweight integrations for MLFlow Tracking.
|
||||
One is the :ref:`MLFlowLoggerCallback <tune-mlflow-logger>`, which automatically logs
|
||||
metrics reported to Tune to the MLFlow Tracking API.
|
||||
|
||||
The other one is the :ref:`@mlflow_mixin <tune-mlflow-mixin>` decorator, which can be
|
||||
used with the function API. It automatically
|
||||
initializes the MLFlow API with Tune's training information and creates a run for each Tune trial.
|
||||
Then within your training function, you can just use the
|
||||
MLFlow API as you normally would, e.g. using ``mlflow.log_metrics()`` or even ``mlflow.autolog()``
|
||||
to log your training process.
|
||||
|
||||
Please :doc:`see here </tune/examples/mlflow_example>` for a full example on how you can use either the
|
||||
MLFlowLoggerCallback or the mlflow_mixin.
|
||||
|
||||
MLFlow AutoLogging
|
||||
------------------
|
||||
You can also check out :doc:`here </tune/examples/mlflow_ptl_example>` for an example on how you can leverage MLflow
|
||||
autologging, in this case with PyTorch Lightning.
|
||||
|
||||
MLFlow Logger API
|
||||
-----------------
|
||||
.. _tune-mlflow-logger:
|
||||
|
||||
.. autoclass:: ray.tune.integration.mlflow.MLFlowLoggerCallback
|
||||
:noindex:
|
||||
|
||||
MLFlow Mixin API
|
||||
----------------
|
||||
.. _tune-mlflow-mixin:
|
||||
|
||||
.. autofunction:: ray.tune.integration.mlflow.mlflow_mixin
|
||||
:noindex:
|
||||
@@ -70,11 +70,7 @@ An example of creating a custom logger can be found in :doc:`/tune/examples/logg
|
||||
Trainable Logging
|
||||
-----------------
|
||||
|
||||
By default, Tune only logs the *training result dictionaries* from your Trainable. However, you may want to visualize the model weights, model graph, or use a custom logging library that requires multi-process logging. For example, you may want to do this if:
|
||||
|
||||
* you're using `Weights and Biases <https://www.wandb.com/>`_
|
||||
* you're using `MLFlow <https://github.com/mlflow/mlflow/>`__
|
||||
* you're trying to log images to Tensorboard.
|
||||
By default, Tune only logs the *training result dictionaries* from your Trainable. However, you may want to visualize the model weights, model graph, or use a custom logging library that requires multi-process logging. For example, you may want to do this if you're trying to log images to Tensorboard.
|
||||
|
||||
You can do this in the trainable, as shown below:
|
||||
|
||||
@@ -163,12 +159,17 @@ CSVLogger
|
||||
|
||||
.. autoclass:: ray.tune.logger.CSVLoggerCallback
|
||||
|
||||
MLFLowLogger
|
||||
MLFlowLogger
|
||||
------------
|
||||
|
||||
Tune also provides a default logger for `MLFlow <https://mlflow.org>`_. You can install MLFlow via ``pip install mlflow``. An example can be found in :doc:`/tune/examples/mlflow_example`. Note that this currently does not include artifact logging support. For this, you can use the native MLFlow APIs inside your Trainable definition.
|
||||
Tune also provides a default logger for `MLFlow <https://mlflow.org>`_. You can install MLFlow via ``pip install mlflow``.
|
||||
You can see the :doc:`tutorial here </tune/tutorials/tune-mlflow>`.
|
||||
|
||||
.. autoclass:: ray.tune.logger.MLFLowLogger
|
||||
WandbLogger
|
||||
-----------
|
||||
|
||||
Tune also provides a default logger for `Weights & Biases <https://www.wandb.com/>`_. You can install Wandb via ``pip install wandb``.
|
||||
You can see the :doc:`tutorial here </tune/tutorials/tune-wandb>`.
|
||||
|
||||
|
||||
.. _logger-interface:
|
||||
|
||||
@@ -192,10 +192,18 @@ For a high-level overview, see this example:
|
||||
# Sample an integer uniformly between -9 (inclusive) and 15 (exclusive)
|
||||
"randint": tune.randint(-9, 15),
|
||||
|
||||
# Sample an integer uniformly between 1 (inclusive) and 10 (exclusive),
|
||||
# while sampling in log space
|
||||
"lograndint": tune.lograndint(1, 10),
|
||||
|
||||
# Sample an integer uniformly between -21 (inclusive) and 12 (inclusive (!))
|
||||
# rounding to increments of 3 (includes 12)
|
||||
"qrandint": tune.qrandint(-21, 12, 3),
|
||||
|
||||
# Sample an integer uniformly between 1 (inclusive) and 10 (inclusive (!)),
|
||||
# while sampling in log space and rounding to increments of 2
|
||||
"qlograndint": tune.qlograndint(1, 10, 2),
|
||||
|
||||
# Sample an option uniformly from the specified choices
|
||||
"choice": tune.choice(["a", "b", "c"]),
|
||||
|
||||
@@ -263,10 +271,7 @@ Grid Search API
|
||||
|
||||
.. autofunction:: ray.tune.grid_search
|
||||
|
||||
Internals
|
||||
---------
|
||||
References
|
||||
----------
|
||||
|
||||
BasicVariantGenerator
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: ray.tune.suggest.BasicVariantGenerator
|
||||
See also :ref:`tune-basicvariant`.
|
||||
|
||||
@@ -22,6 +22,10 @@ Summary
|
||||
- Summary
|
||||
- Website
|
||||
- Code Example
|
||||
* - :ref:`Random search/grid search <tune-basicvariant>`
|
||||
- Random search/grid search
|
||||
-
|
||||
- :doc:`/tune/examples/tune_basic_example`
|
||||
* - :ref:`AxSearch <tune-ax>`
|
||||
- Bayesian/Bandit Optimization
|
||||
- [`Ax <https://ax.dev/>`__]
|
||||
@@ -123,6 +127,21 @@ identifier.
|
||||
|
||||
.. note:: This is currently not implemented for: AxSearch, TuneBOHB, SigOptSearch, and DragonflySearch.
|
||||
|
||||
.. _tune-basicvariant:
|
||||
|
||||
Random search and grid search (tune.suggest.basic_variant.BasicVariantGenerator)
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
The default and most basic way to do hyperparameter search is via random and grid search.
|
||||
Ray Tune does this through the :class:`BasicVariantGenerator <ray.tune.suggest.basic_variant.BasicVariantGenerator>`
|
||||
class that generates trial variants given a search space definition.
|
||||
|
||||
The :class:`BasicVariantGenerator <ray.tune.suggest.basic_variant.BasicVariantGenerator>` is used by
|
||||
default if no search algorithm is passed to
|
||||
:func:`tune.run() <ray.tune.run>`.
|
||||
|
||||
.. autoclass:: ray.tune.suggest.basic_variant.BasicVariantGenerator
|
||||
|
||||
.. _tune-ax:
|
||||
|
||||
Ax (tune.suggest.ax.AxSearch)
|
||||
|
||||
@@ -13,7 +13,7 @@ If any example is broken, or if you'd like to add an example to this page, feel
|
||||
General Examples
|
||||
----------------
|
||||
|
||||
|
||||
- :doc:`/tune/examples/tune_basic_example`: Simple example for doing a basic random and grid search.
|
||||
- :doc:`/tune/examples/async_hyperband_example`: Example of using a simple tuning function with AsyncHyperBandScheduler.
|
||||
- :doc:`/tune/examples/hyperband_function_example`: Example of using a Trainable function with HyperBandScheduler. Also uses the AsyncHyperBandScheduler.
|
||||
- :doc:`/tune/examples/pbt_function`: Example of using the function API with a PopulationBasedTraining scheduler.
|
||||
@@ -88,6 +88,7 @@ Wandb, MLFlow
|
||||
- :ref:`Tutorial <tune-wandb>` for using `wandb <https://www.wandb.com/>`__ with Ray Tune
|
||||
- :doc:`/tune/examples/wandb_example`: Example for using `Weights and Biases <https://www.wandb.com/>`__ with Ray Tune.
|
||||
- :doc:`/tune/examples/mlflow_example`: Example for using `MLFlow <https://github.com/mlflow/mlflow/>`__ with Ray Tune.
|
||||
- :doc:`/tune/examples/mlflow_ptl_example`: Example for using `MLFlow <https://github.com/mlflow/mlflow/>`__ and `Pytorch Lightning <https://github.com/PyTorchLightning/pytorch-lightning>`_ with Ray Tune.
|
||||
|
||||
Tensorflow/Keras
|
||||
~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -0,0 +1,6 @@
|
||||
:orphan:
|
||||
|
||||
mlflow_ptl_example
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. literalinclude:: /../../python/ray/tune/examples/mlflow_ptl.py
|
||||
@@ -0,0 +1,6 @@
|
||||
:orphan:
|
||||
|
||||
tune_basic_example
|
||||
~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. literalinclude:: /../../python/ray/tune/examples/tune_basic_example.py
|
||||
@@ -417,6 +417,8 @@ actors.
|
||||
to the object ref returned by the put exists. This only applies to the specific
|
||||
ref returned by put, not refs in general or copies of that ref.
|
||||
|
||||
See also: `object spilling <advanced.html#object-spilling>`__.
|
||||
|
||||
Remote Classes (Actors)
|
||||
-----------------------
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import java.util.Random;
|
||||
|
||||
public class ActorId extends BaseId implements Serializable {
|
||||
|
||||
private static final int UNIQUE_BYTES_LENGTH = 4;
|
||||
private static final int UNIQUE_BYTES_LENGTH = 12;
|
||||
|
||||
public static final int LENGTH = JobId.LENGTH + UNIQUE_BYTES_LENGTH;
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import java.util.Random;
|
||||
*/
|
||||
public class ObjectId extends BaseId implements Serializable {
|
||||
|
||||
public static final int LENGTH = 20;
|
||||
public static final int LENGTH = 28;
|
||||
|
||||
/**
|
||||
* Create an ObjectId from a ByteBuffer.
|
||||
|
||||
@@ -11,7 +11,7 @@ import java.util.Random;
|
||||
*/
|
||||
public class UniqueId extends BaseId implements Serializable {
|
||||
|
||||
public static final int LENGTH = 20;
|
||||
public static final int LENGTH = 28;
|
||||
public static final UniqueId NIL = genNil();
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
package io.ray.api.options;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
/**
|
||||
* The options class for RayCall or ActorCreation.
|
||||
*/
|
||||
public abstract class BaseTaskOptions {
|
||||
public abstract class BaseTaskOptions implements Serializable {
|
||||
|
||||
public final Map<String, Double> resources;
|
||||
|
||||
|
||||
@@ -72,9 +72,8 @@ public interface RayRuntime {
|
||||
*
|
||||
* @param objectRefs The object references to free.
|
||||
* @param localOnly Whether only free objects for local object store or not.
|
||||
* @param deleteCreatingTasks Whether also delete objects' creating tasks from GCS.
|
||||
*/
|
||||
void free(List<ObjectRef<?>> objectRefs, boolean localOnly, boolean deleteCreatingTasks);
|
||||
void free(List<ObjectRef<?>> objectRefs, boolean localOnly);
|
||||
|
||||
/**
|
||||
* Set the resource for the specific node.
|
||||
|
||||
@@ -100,9 +100,9 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void free(List<ObjectRef<?>> objectRefs, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
public void free(List<ObjectRef<?>> objectRefs, boolean localOnly) {
|
||||
objectStore.delete(objectRefs.stream().map(ref -> ((ObjectRefImpl<?>) ref).getId()).collect(
|
||||
Collectors.toList()), localOnly, deleteCreatingTasks);
|
||||
Collectors.toList()), localOnly);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -3,10 +3,8 @@ package io.ray.runtime.gcs;
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.protobuf.InvalidProtocolBufferException;
|
||||
import io.ray.api.id.ActorId;
|
||||
import io.ray.api.id.BaseId;
|
||||
import io.ray.api.id.JobId;
|
||||
import io.ray.api.id.PlacementGroupId;
|
||||
import io.ray.api.id.TaskId;
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.api.placementgroup.PlacementGroup;
|
||||
import io.ray.api.runtimecontext.NodeInfo;
|
||||
@@ -14,12 +12,10 @@ import io.ray.runtime.generated.Gcs;
|
||||
import io.ray.runtime.generated.Gcs.GcsNodeInfo;
|
||||
import io.ray.runtime.generated.Gcs.TablePrefix;
|
||||
import io.ray.runtime.placementgroup.PlacementGroupUtils;
|
||||
import io.ray.runtime.util.IdUtil;
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.stream.Collectors;
|
||||
import org.apache.commons.lang3.ArrayUtils;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
@@ -31,25 +27,10 @@ public class GcsClient {
|
||||
private static Logger LOGGER = LoggerFactory.getLogger(GcsClient.class);
|
||||
private RedisClient primary;
|
||||
|
||||
private List<RedisClient> shards;
|
||||
private GlobalStateAccessor globalStateAccessor;
|
||||
|
||||
public GcsClient(String redisAddress, String redisPassword) {
|
||||
primary = new RedisClient(redisAddress, redisPassword);
|
||||
int numShards = 0;
|
||||
try {
|
||||
numShards = Integer.valueOf(primary.get("NumRedisShards", null));
|
||||
Preconditions.checkState(numShards > 0,
|
||||
String.format("Expected at least one Redis shards, found %d.", numShards));
|
||||
} catch (NumberFormatException e) {
|
||||
throw new RuntimeException("Failed to get number of redis shards.", e);
|
||||
}
|
||||
|
||||
List<byte[]> shardAddresses = primary.lrange("RedisShards".getBytes(), 0, -1);
|
||||
Preconditions.checkState(shardAddresses.size() == numShards);
|
||||
shards = shardAddresses.stream().map((byte[] address) -> {
|
||||
return new RedisClient(new String(address), redisPassword);
|
||||
}).collect(Collectors.toList());
|
||||
globalStateAccessor = GlobalStateAccessor.getInstance(redisAddress, redisPassword);
|
||||
}
|
||||
|
||||
@@ -163,16 +144,6 @@ public class GcsClient {
|
||||
return actorTableData.getNumRestarts() != 0;
|
||||
}
|
||||
|
||||
/**
|
||||
* Query whether the raylet task exists in Gcs.
|
||||
*/
|
||||
public boolean rayletTaskExistsInGcs(TaskId taskId) {
|
||||
byte[] key = ArrayUtils.addAll(TablePrefix.RAYLET_TASK.toString().getBytes(),
|
||||
taskId.getBytes());
|
||||
RedisClient client = getShardClient(taskId);
|
||||
return client.exists(key);
|
||||
}
|
||||
|
||||
public JobId nextJobId() {
|
||||
int jobCounter = (int) primary.incr("JobCounter".getBytes());
|
||||
return JobId.fromInt(jobCounter);
|
||||
@@ -186,10 +157,4 @@ public class GcsClient {
|
||||
LOGGER.debug("Destroying global state accessor.");
|
||||
GlobalStateAccessor.destroyInstance();
|
||||
}
|
||||
|
||||
private RedisClient getShardClient(BaseId key) {
|
||||
return shards.get((int) Long.remainderUnsigned(IdUtil.murmurHashCode(key),
|
||||
shards.size()));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ public class LocalModeObjectStore extends ObjectStore {
|
||||
|
||||
private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeObjectStore.class);
|
||||
|
||||
private static final int GET_CHECK_INTERVAL_MS = 100;
|
||||
private static final int GET_CHECK_INTERVAL_MS = 1;
|
||||
|
||||
private final Map<ObjectId, NativeRayObject> pool = new ConcurrentHashMap<>();
|
||||
private final List<Consumer<ObjectId>> objectPutCallbacks = new ArrayList<>();
|
||||
@@ -93,7 +93,7 @@ public class LocalModeObjectStore extends ObjectStore {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
public void delete(List<ObjectId> objectIds, boolean localOnly) {
|
||||
for (ObjectId objectId : objectIds) {
|
||||
pool.remove(objectId);
|
||||
}
|
||||
|
||||
@@ -50,8 +50,8 @@ public class NativeObjectStore extends ObjectStore {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
|
||||
nativeDelete(toBinaryList(objectIds), localOnly, deleteCreatingTasks);
|
||||
public void delete(List<ObjectId> objectIds, boolean localOnly) {
|
||||
nativeDelete(toBinaryList(objectIds), localOnly);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -116,8 +116,7 @@ public class NativeObjectStore extends ObjectStore {
|
||||
private static native List<Boolean> nativeWait(List<byte[]> objectIds, int numObjects,
|
||||
long timeoutMs);
|
||||
|
||||
private static native void nativeDelete(List<byte[]> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks);
|
||||
private static native void nativeDelete(List<byte[]> objectIds, boolean localOnly);
|
||||
|
||||
private static native void nativeAddLocalReference(byte[] workerId, byte[] objectId);
|
||||
|
||||
|
||||
@@ -167,10 +167,8 @@ public abstract class ObjectStore {
|
||||
*
|
||||
* @param objectIds IDs of the objects to delete.
|
||||
* @param localOnly Whether only delete the objects in local node, or all nodes in the cluster.
|
||||
* @param deleteCreatingTasks Whether also delete the tasks that created these objects.
|
||||
*/
|
||||
public abstract void delete(List<ObjectId> objectIds, boolean localOnly,
|
||||
boolean deleteCreatingTasks);
|
||||
public abstract void delete(List<ObjectId> objectIds, boolean localOnly);
|
||||
|
||||
/**
|
||||
* Increase the local reference count for this object ID.
|
||||
|
||||
@@ -75,14 +75,18 @@ public class RunManager {
|
||||
// address info of the local node.
|
||||
String script = String.format("import ray;"
|
||||
+ " print(ray._private.services.get_address_info_from_redis("
|
||||
+ "'%s', '%s', redis_password='%s', no_warning=True))",
|
||||
+ "'%s', '%s', redis_password='%s'))",
|
||||
rayConfig.getRedisAddress(), rayConfig.nodeIp, rayConfig.redisPassword);
|
||||
List<String> command = Arrays.asList("python", "-c", script);
|
||||
|
||||
String output = null;
|
||||
try {
|
||||
output = runCommand(command);
|
||||
JsonObject addressInfo = new JsonParser().parse(output).getAsJsonObject();
|
||||
// NOTE(kfstorm): We only parse the last line here in case some warning
|
||||
// messages appear at the beginning.
|
||||
String[] lines = output.split(System.lineSeparator());
|
||||
String lastLine = lines[lines.length - 1];
|
||||
JsonObject addressInfo = new JsonParser().parse(lastLine).getAsJsonObject();
|
||||
rayConfig.rayletSocketName = addressInfo.get("raylet_socket_name").getAsString();
|
||||
rayConfig.objectStoreSocketName = addressInfo.get("object_store_address").getAsString();
|
||||
rayConfig.nodeManagerPort = addressInfo.get("node_manager_port").getAsInt();
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package io.ray.runtime.util;
|
||||
|
||||
import io.ray.api.id.ActorId;
|
||||
import io.ray.api.id.BaseId;
|
||||
import io.ray.api.id.ObjectId;
|
||||
import io.ray.api.id.TaskId;
|
||||
|
||||
@@ -11,74 +10,6 @@ import io.ray.api.id.TaskId;
|
||||
*/
|
||||
public class IdUtil {
|
||||
|
||||
/**
|
||||
* Compute the murmur hash code of this ID.
|
||||
*/
|
||||
public static long murmurHashCode(BaseId id) {
|
||||
return murmurHash64A(id.getBytes(), id.size(), 0);
|
||||
}
|
||||
|
||||
/**
|
||||
* This method is the same as `Hash()` method of `ID` class in ray/src/ray/common/id.h
|
||||
*/
|
||||
private static long murmurHash64A(byte[] data, int length, int seed) {
|
||||
final long m = 0xc6a4a7935bd1e995L;
|
||||
final int r = 47;
|
||||
|
||||
long h = (seed & 0xFFFFFFFFL) ^ (length * m);
|
||||
|
||||
int length8 = length / 8;
|
||||
|
||||
for (int i = 0; i < length8; i++) {
|
||||
final int i8 = i * 8;
|
||||
long k = ((long) data[i8] & 0xff)
|
||||
+ (((long) data[i8 + 1] & 0xff) << 8)
|
||||
+ (((long) data[i8 + 2] & 0xff) << 16)
|
||||
+ (((long) data[i8 + 3] & 0xff) << 24)
|
||||
+ (((long) data[i8 + 4] & 0xff) << 32)
|
||||
+ (((long) data[i8 + 5] & 0xff) << 40)
|
||||
+ (((long) data[i8 + 6] & 0xff) << 48)
|
||||
+ (((long) data[i8 + 7] & 0xff) << 56);
|
||||
|
||||
k *= m;
|
||||
k ^= k >>> r;
|
||||
k *= m;
|
||||
|
||||
h ^= k;
|
||||
h *= m;
|
||||
}
|
||||
|
||||
final int remaining = length % 8;
|
||||
if (remaining >= 7) {
|
||||
h ^= (long) (data[(length & ~7) + 6] & 0xff) << 48;
|
||||
}
|
||||
if (remaining >= 6) {
|
||||
h ^= (long) (data[(length & ~7) + 5] & 0xff) << 40;
|
||||
}
|
||||
if (remaining >= 5) {
|
||||
h ^= (long) (data[(length & ~7) + 4] & 0xff) << 32;
|
||||
}
|
||||
if (remaining >= 4) {
|
||||
h ^= (long) (data[(length & ~7) + 3] & 0xff) << 24;
|
||||
}
|
||||
if (remaining >= 3) {
|
||||
h ^= (long) (data[(length & ~7) + 2] & 0xff) << 16;
|
||||
}
|
||||
if (remaining >= 2) {
|
||||
h ^= (long) (data[(length & ~7) + 1] & 0xff) << 8;
|
||||
}
|
||||
if (remaining >= 1) {
|
||||
h ^= (long) (data[length & ~7] & 0xff);
|
||||
h *= m;
|
||||
}
|
||||
|
||||
h ^= h >>> r;
|
||||
h *= m;
|
||||
h ^= h >>> r;
|
||||
|
||||
return h;
|
||||
}
|
||||
|
||||
/**
|
||||
* Compute the actor ID of the task which created this object.
|
||||
* @return The actor ID of the task which created this object.
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package io.ray.runtime;
|
||||
|
||||
import io.ray.api.id.UniqueId;
|
||||
import io.ray.runtime.util.IdUtil;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.util.Arrays;
|
||||
import javax.xml.bind.DatatypeConverter;
|
||||
@@ -13,12 +12,12 @@ public class UniqueIdTest {
|
||||
@Test
|
||||
public void testConstructUniqueId() {
|
||||
// Test `fromHexString()`
|
||||
UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00");
|
||||
Assert.assertEquals("00000000123456789abcdef123456789abcdef00", id1.toString());
|
||||
UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00");
|
||||
Assert.assertEquals("00000000123456789abcdef123456789abcdef0123456789abcdef00", id1.toString());
|
||||
Assert.assertFalse(id1.isNil());
|
||||
|
||||
try {
|
||||
UniqueId id2 = UniqueId.fromHexString("000000123456789ABCDEF123456789ABCDEF00");
|
||||
UniqueId id2 = UniqueId.fromHexString("000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00");
|
||||
// This shouldn't happen.
|
||||
Assert.assertTrue(false);
|
||||
} catch (IllegalArgumentException e) {
|
||||
@@ -34,23 +33,16 @@ public class UniqueIdTest {
|
||||
}
|
||||
|
||||
// Test `fromByteBuffer()`
|
||||
byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF01234567");
|
||||
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 20);
|
||||
byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF012345670123456789ABCDEF");
|
||||
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 28);
|
||||
UniqueId id4 = UniqueId.fromByteBuffer(byteBuffer);
|
||||
Assert.assertTrue(Arrays.equals(bytes, id4.getBytes()));
|
||||
Assert.assertEquals("0123456789abcdef0123456789abcdef01234567", id4.toString());
|
||||
Assert.assertEquals("0123456789abcdef0123456789abcdef012345670123456789abcdef", id4.toString());
|
||||
|
||||
|
||||
// Test `genNil()`
|
||||
UniqueId id6 = UniqueId.NIL;
|
||||
Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString());
|
||||
Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString());
|
||||
Assert.assertTrue(id6.isNil());
|
||||
}
|
||||
|
||||
@Test
|
||||
void testMurmurHash() {
|
||||
UniqueId id = UniqueId.fromHexString("3131313131313131313132323232323232323232");
|
||||
long remainder = Long.remainderUnsigned(IdUtil.murmurHashCode(id), 1000000000);
|
||||
Assert.assertEquals(remainder, 787616861);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,7 +128,7 @@ public class ActorTest extends BaseTest {
|
||||
ObjectRef value = counter.task(Counter::getValue).remote();
|
||||
Assert.assertEquals(100, value.get());
|
||||
// Delete the object from the object store.
|
||||
Ray.internal().free(ImmutableList.of(value), false, false);
|
||||
Ray.internal().free(ImmutableList.of(value), false);
|
||||
// Wait for delete RPC to propagate
|
||||
TimeUnit.SECONDS.sleep(1);
|
||||
// Free deletes from in-memory store.
|
||||
@@ -138,7 +138,7 @@ public class ActorTest extends BaseTest {
|
||||
ObjectRef<TestUtils.LargeObject> largeValue = counter.task(Counter::createLargeObject).remote();
|
||||
Assert.assertTrue(largeValue.get() instanceof TestUtils.LargeObject);
|
||||
// Delete the object from the object store.
|
||||
Ray.internal().free(ImmutableList.of(largeValue), false, false);
|
||||
Ray.internal().free(ImmutableList.of(largeValue), false);
|
||||
// Wait for delete RPC to propagate
|
||||
TimeUnit.SECONDS.sleep(1);
|
||||
// Free deletes big objects from plasma store.
|
||||
|
||||
@@ -15,7 +15,8 @@ public class DynamicResourceTest extends BaseTest {
|
||||
return "hi";
|
||||
}
|
||||
|
||||
@Test(groups = {"cluster"})
|
||||
// Dynamic resources not supported yet.
|
||||
@Test(groups = {"cluster"}, enabled = false)
|
||||
public void testSetResource() {
|
||||
// Call a task in advance to warm up the cluster to avoid being too slow to start workers.
|
||||
TestUtils.warmUpCluster();
|
||||
|
||||
@@ -3,9 +3,7 @@ package io.ray.test;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import io.ray.api.ObjectRef;
|
||||
import io.ray.api.Ray;
|
||||
import io.ray.api.id.TaskId;
|
||||
import io.ray.runtime.object.ObjectRefImpl;
|
||||
import java.util.Arrays;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
@@ -20,7 +18,7 @@ public class PlasmaFreeTest extends BaseTest {
|
||||
ObjectRef<String> helloId = Ray.task(PlasmaFreeTest::hello).remote();
|
||||
String helloString = helloId.get();
|
||||
Assert.assertEquals("hello", helloString);
|
||||
Ray.internal().free(ImmutableList.of(helloId), true, false);
|
||||
Ray.internal().free(ImmutableList.of(helloId), true);
|
||||
|
||||
final boolean result = TestUtils.waitForCondition(() ->
|
||||
!TestUtils.getRuntime().getObjectStore()
|
||||
@@ -32,19 +30,4 @@ public class PlasmaFreeTest extends BaseTest {
|
||||
Assert.assertFalse(result);
|
||||
}
|
||||
}
|
||||
|
||||
@Test(groups = {"cluster"})
|
||||
public void testDeleteCreatingTasks() {
|
||||
ObjectRef<String> helloId = Ray.task(PlasmaFreeTest::hello).remote();
|
||||
Assert.assertEquals("hello", helloId.get());
|
||||
Ray.internal().free(ImmutableList.of(helloId), true, true);
|
||||
|
||||
TaskId taskId = TaskId.fromBytes(
|
||||
Arrays.copyOf(((ObjectRefImpl<String>) helloId).getId().getBytes(), TaskId.LENGTH));
|
||||
final boolean result = TestUtils.waitForCondition(
|
||||
() -> !TestUtils.getRuntime().getGcsClient()
|
||||
.rayletTaskExistsInGcs(taskId), 50);
|
||||
Assert.assertTrue(result);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from functools import wraps
|
||||
|
||||
client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1"
|
||||
|
||||
_client_hook_enabled = True
|
||||
|
||||
|
||||
def _enable_client_hook(val: bool):
    """Set the module-wide flag that gates client-mode delegation.

    Args:
        val: New value for ``_client_hook_enabled``.
    """
    global _client_hook_enabled
    _client_hook_enabled = val
|
||||
|
||||
|
||||
def _disable_client_hook():
    """Switch the client hook off and report what it was set to before.

    Returns:
        The value ``_client_hook_enabled`` held immediately before this
        call.
    """
    global _client_hook_enabled
    previous, _client_hook_enabled = _client_hook_enabled, False
    return previous
|
||||
|
||||
|
||||
def _explicitly_enable_client_mode():
    """Force ``client_mode_enabled`` to True for the rest of the process.

    This overrides whatever value was derived from the ``RAY_CLIENT_MODE``
    environment variable when the module was loaded.
    """
    global client_mode_enabled
    client_mode_enabled = True
|
||||
|
||||
|
||||
@contextmanager
def disable_client_hook():
    """Context manager that turns the client hook off for its body.

    The hook state found on entry is put back on exit, including when the
    body raises.
    """
    previous = _disable_client_hook()
    try:
        yield
    finally:
        _enable_client_hook(previous)
|
||||
|
||||
|
||||
def client_mode_hook(func):
    """Decorator for ray module methods to delegate to ray client.

    While both ``client_mode_enabled`` and ``_client_hook_enabled`` are
    true, a call to the decorated function is forwarded to the attribute of
    the same name on ``ray.experimental.client.ray``. Otherwise ``func``
    itself is called.
    """
    # Imported when the decorator is applied rather than at module import.
    from ray.experimental.client import ray

    @wraps(func)
    def wrapper(*args, **kwargs):
        if not (client_mode_enabled and _client_hook_enabled):
            return func(*args, **kwargs)
        client_impl = getattr(ray, func.__name__)
        return client_impl(*args, **kwargs)

    return wrapper
|
||||
@@ -0,0 +1,83 @@
|
||||
import inspect
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from ray.experimental.client.ray_client_helpers import ray_start_client_server
|
||||
|
||||
from ray._private.ray_microbenchmark_helpers import timeit
|
||||
from ray._private.ray_microbenchmark_helpers import ray_setup_and_teardown
|
||||
|
||||
|
||||
def benchmark_get_calls(ray):
    """Time repeated ``ray.get`` calls on one object stored via ``ray.put``.

    Args:
        ray: Object exposing ``put`` and ``get``; ``main`` passes the value
            yielded by ``ray_start_client_server``.
    """
    stored_ref = ray.put(0)

    def fetch_once():
        ray.get(stored_ref)

    timeit("client: get calls", fetch_once)
|
||||
|
||||
|
||||
def benchmark_put_calls(ray):
    """Time repeated ``ray.put`` calls storing the integer 0.

    Args:
        ray: Object exposing ``put``; ``main`` passes the value yielded by
            ``ray_start_client_server``.
    """

    def store_once():
        ray.put(0)

    timeit("client: put calls", store_once)
|
||||
|
||||
|
||||
def benchmark_remote_put_calls(ray):
    """Time ``ray.put`` calls issued from inside remote tasks.

    Args:
        ray: Object exposing ``remote``, ``put`` and ``get``; ``main``
            passes the value yielded by ``ray_start_client_server``.
    """

    @ray.remote
    def do_put_small():
        for _ in range(100):
            ray.put(0)

    def run_ten_tasks():
        pending = [do_put_small.remote() for _ in range(10)]
        ray.get(pending)

    # 10 tasks x 100 puts each = 1000 puts per timed call.
    timeit("client: remote put calls", run_ten_tasks, 1000)
|
||||
|
||||
|
||||
def benchmark_simple_actor(ray):
    """Time method calls on a single zero-CPU actor created through ``ray``.

    Three measurements are printed: one blocking call at a time, a batch of
    1000 calls submitted before waiting, and the same batch against a new
    actor created with ``max_concurrency=16``.

    Args:
        ray: Object exposing ``remote`` and ``get``; ``main`` passes the
            value yielded by ``ray_start_client_server``.
    """

    @ray.remote(num_cpus=0)
    class Actor:
        def small_value(self):
            return b"ok"

        def small_value_arg(self, x):
            return b"ok"

        def small_value_batch(self, n):
            ray.get([self.small_value.remote() for _ in range(n)])

    actor = Actor.remote()

    def call_once():
        ray.get(actor.small_value.remote())

    def call_batch():
        # ``actor`` is looked up at call time, so this closure follows the
        # rebinding below.
        ray.get([actor.small_value.remote() for _ in range(1000)])

    timeit("client: 1:1 actor calls sync", call_once)
    timeit("client: 1:1 actor calls async", call_batch, 1000)

    # Reuse the same variable so the handle to the first actor is dropped
    # here rather than kept alive during the concurrent run.
    actor = Actor.options(max_concurrency=16).remote()
    timeit("client: 1:1 actor calls concurrent", call_batch, 1000)
|
||||
|
||||
|
||||
def main():
    """Run every ``benchmark_*`` callable defined in this module.

    Ray is initialised once for the whole run. Each benchmark then gets its
    own ``ray_start_client_server`` context and receives the yielded object
    as its only argument. ``inspect.getmembers`` returns members sorted by
    name, so benchmarks run in alphabetical order.
    """
    init_kwargs = {
        "logging_level": logging.WARNING,
        "_system_config": {"put_small_object_in_memory_store": True},
    }
    with ray_setup_and_teardown(**init_kwargs):
        for name, member in inspect.getmembers(sys.modules[__name__]):
            if name.startswith("benchmark_"):
                with ray_start_client_server() as ray:
                    member(ray)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,39 @@
|
||||
import time
|
||||
import os
|
||||
import ray
|
||||
import numpy as np
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
# Only run tests matching this filter pattern.
|
||||
filter_pattern = os.environ.get("TESTS_TO_RUN", "")
|
||||
|
||||
|
||||
def timeit(name, fn, multiplier=1):
    """Benchmark ``fn`` and print its throughput in operations per second.

    Nothing is run unless ``filter_pattern`` is a substring of ``name``.
    Otherwise ``fn`` is called repeatedly for about one second as a warm-up,
    then for four rounds of about two seconds each. The mean and standard
    deviation of the per-round rates are printed.

    Args:
        name: Label printed with the result and matched against
            ``filter_pattern``.
        fn: Zero-argument callable to benchmark; its return value is
            ignored.
        multiplier: Number of logical operations performed by one ``fn()``
            call; the reported rate is scaled by it.
    """
    if filter_pattern not in name:
        return
    # perf_counter() is monotonic, so a system clock adjustment during the
    # run cannot distort (or negate) a measured interval as time.time() can.
    clock = time.perf_counter
    # warmup
    start = clock()
    while clock() - start < 1:
        fn()
    # real run
    stats = []
    for _ in range(4):
        start = clock()
        count = 0
        while clock() - start < 2:
            fn()
            count += 1
        end = clock()
        stats.append(multiplier * count / (end - start))
    print(name, "per second", round(np.mean(stats), 2), "+-",
          round(np.std(stats), 2))
|
||||
|
||||
|
||||
@contextmanager
def ray_setup_and_teardown(**init_args):
    """Context manager that runs its body between ``ray.init`` and
    ``ray.shutdown``.

    Args:
        **init_args: Keyword arguments passed straight to ``ray.init``.

    ``ray.shutdown`` is called when the body finishes or raises. It is not
    called if ``ray.init`` itself raises.
    """
    ray.init(**init_args)
    try:
        yield
    finally:
        ray.shutdown()
|
||||
@@ -279,8 +279,7 @@ def get_address_info_from_redis_helper(redis_address,
|
||||
def get_address_info_from_redis(redis_address,
|
||||
node_ip_address,
|
||||
num_retries=5,
|
||||
redis_password=None,
|
||||
no_warning=False):
|
||||
redis_password=None):
|
||||
counter = 0
|
||||
while True:
|
||||
try:
|
||||
@@ -291,11 +290,10 @@ def get_address_info_from_redis(redis_address,
|
||||
raise
|
||||
# Some of the information may not be in Redis yet, so wait a little
|
||||
# bit.
|
||||
if not no_warning:
|
||||
logger.warning(
|
||||
"Some processes that the driver needs to connect to have "
|
||||
"not registered with Redis, so retrying. Have you run "
|
||||
"'ray start' on this node?")
|
||||
logger.warning(
|
||||
"Some processes that the driver needs to connect to have "
|
||||
"not registered with Redis, so retrying. Have you run "
|
||||
"'ray start' on this node?")
|
||||
time.sleep(1)
|
||||
counter += 1
|
||||
|
||||
@@ -1618,12 +1616,13 @@ def determine_plasma_store_config(object_store_memory,
|
||||
logger.warning(
|
||||
"WARNING: The object store is using {} instead of "
|
||||
"/dev/shm because /dev/shm has only {} bytes available. "
|
||||
"This may slow down performance! You may be able to free "
|
||||
"up space by deleting files in /dev/shm or terminating "
|
||||
"any running plasma_store_server processes. If you are "
|
||||
"inside a Docker container, you may need to pass an "
|
||||
"argument with the flag '--shm-size' to 'docker run'.".
|
||||
format(ray.utils.get_user_temp_dir(), shm_avail))
|
||||
"This will harm performance! You may be able to free up "
|
||||
"space by deleting files in /dev/shm. If you are inside a "
|
||||
"Docker container, you can increase /dev/shm size by "
|
||||
"passing '--shm-size=Xgb' to 'docker run' (or add it to "
|
||||
"the run_options list in a Ray cluster config). Make sure "
|
||||
"to set this to more than 2gb.".format(
|
||||
ray.utils.get_user_temp_dir(), shm_avail))
|
||||
else:
|
||||
plasma_directory = ray.utils.get_user_temp_dir()
|
||||
|
||||
|
||||
+19
-19
@@ -107,6 +107,10 @@ from ray.exceptions import (
|
||||
TaskCancelledError
|
||||
)
|
||||
from ray.utils import decode
|
||||
from ray._private.client_mode_hook import (
|
||||
_enable_client_hook,
|
||||
_disable_client_hook,
|
||||
)
|
||||
import msgpack
|
||||
|
||||
cimport cpython
|
||||
@@ -558,6 +562,7 @@ cdef CRayStatus task_execution_handler(
|
||||
|
||||
with gil:
|
||||
try:
|
||||
client_was_enabled = _disable_client_hook()
|
||||
try:
|
||||
# The call to execute_task should never raise an exception. If
|
||||
# it does, that indicates that there was an internal error.
|
||||
@@ -582,6 +587,8 @@ cdef CRayStatus task_execution_handler(
|
||||
else:
|
||||
logger.exception("SystemExit was raised from the worker")
|
||||
return CRayStatus.UnexpectedSystemExit()
|
||||
finally:
|
||||
_enable_client_hook(client_was_enabled)
|
||||
|
||||
return CRayStatus.OK()
|
||||
|
||||
@@ -638,9 +645,11 @@ cdef c_vector[c_string] spill_objects_handler(
|
||||
return return_urls
|
||||
|
||||
|
||||
cdef void restore_spilled_objects_handler(
|
||||
cdef int64_t restore_spilled_objects_handler(
|
||||
const c_vector[CObjectID]& object_ids_to_restore,
|
||||
const c_vector[c_string]& object_urls) nogil:
|
||||
cdef:
|
||||
int64_t bytes_restored = 0
|
||||
with gil:
|
||||
urls = []
|
||||
size = object_urls.size()
|
||||
@@ -651,7 +660,8 @@ cdef void restore_spilled_objects_handler(
|
||||
with ray.worker._changeproctitle(
|
||||
ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER,
|
||||
ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER_IDLE):
|
||||
external_storage.restore_spilled_objects(object_refs, urls)
|
||||
bytes_restored = external_storage.restore_spilled_objects(
|
||||
object_refs, urls)
|
||||
except Exception:
|
||||
exception_str = (
|
||||
"An unexpected internal error occurred while the IO worker "
|
||||
@@ -662,6 +672,7 @@ cdef void restore_spilled_objects_handler(
|
||||
"restore_spilled_objects_error",
|
||||
traceback.format_exc() + exception_str,
|
||||
job_id=None)
|
||||
return bytes_restored
|
||||
|
||||
|
||||
cdef void delete_spilled_objects_handler(
|
||||
@@ -873,7 +884,8 @@ cdef class CoreWorker:
|
||||
return self.plasma_event_handler
|
||||
|
||||
def get_objects(self, object_refs, TaskID current_task_id,
|
||||
int64_t timeout_ms=-1, plasma_objects_only=False):
|
||||
int64_t timeout_ms=-1,
|
||||
plasma_objects_only=False):
|
||||
cdef:
|
||||
c_vector[shared_ptr[CRayObject]] results
|
||||
CTaskID c_task_id = current_task_id.native()
|
||||
@@ -1004,7 +1016,7 @@ cdef class CoreWorker:
|
||||
return c_object_id.Binary()
|
||||
|
||||
def wait(self, object_refs, int num_returns, int64_t timeout_ms,
|
||||
TaskID current_task_id):
|
||||
TaskID current_task_id, c_bool fetch_local):
|
||||
cdef:
|
||||
c_vector[CObjectID] wait_ids
|
||||
c_vector[c_bool] results
|
||||
@@ -1013,7 +1025,7 @@ cdef class CoreWorker:
|
||||
wait_ids = ObjectRefsToVector(object_refs)
|
||||
with nogil:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Wait(
|
||||
wait_ids, num_returns, timeout_ms, &results))
|
||||
wait_ids, num_returns, timeout_ms, &results, fetch_local))
|
||||
|
||||
assert len(results) == len(object_refs)
|
||||
|
||||
@@ -1026,14 +1038,13 @@ cdef class CoreWorker:
|
||||
|
||||
return ready, not_ready
|
||||
|
||||
def free_objects(self, object_refs, c_bool local_only,
|
||||
c_bool delete_creating_tasks):
|
||||
def free_objects(self, object_refs, c_bool local_only):
|
||||
cdef:
|
||||
c_vector[CObjectID] free_ids = ObjectRefsToVector(object_refs)
|
||||
|
||||
with nogil:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker().Delete(
|
||||
free_ids, local_only, delete_creating_tasks))
|
||||
free_ids, local_only))
|
||||
|
||||
def global_gc(self):
|
||||
with nogil:
|
||||
@@ -1573,17 +1584,6 @@ cdef class CoreWorker:
|
||||
resource_name.encode("ascii"), capacity,
|
||||
CNodeID.FromBinary(client_id.binary()))
|
||||
|
||||
def force_spill_objects(self, object_refs):
|
||||
cdef c_vector[CObjectID] object_ids
|
||||
object_ids = ObjectRefsToVector(object_refs)
|
||||
assert not RayConfig.instance().automatic_object_deletion_enabled(), (
|
||||
"Automatic object deletion is not supported for"
|
||||
"force_spill_objects yet. Please set"
|
||||
"automatic_object_deletion_enabled: False in Ray's system config.")
|
||||
with nogil:
|
||||
check_status(CCoreWorkerProcess.GetCoreWorker()
|
||||
.SpillObjects(object_ids))
|
||||
|
||||
cdef void async_set_result(shared_ptr[CRayObject] obj,
|
||||
CObjectID object_ref,
|
||||
void *future) with gil:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections import defaultdict, namedtuple
|
||||
from collections import defaultdict, namedtuple, Counter
|
||||
from typing import Any, Optional, Dict, List
|
||||
from urllib3.exceptions import MaxRetryError
|
||||
import copy
|
||||
@@ -16,8 +16,10 @@ from ray.experimental.internal_kv import _internal_kv_put, \
|
||||
from ray.autoscaler.tags import (
|
||||
TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
|
||||
TAG_RAY_FILE_MOUNTS_CONTENTS, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_KIND,
|
||||
TAG_RAY_USER_NODE_TYPE, STATUS_UP_TO_DATE, NODE_KIND_WORKER,
|
||||
NODE_KIND_UNMANAGED, NODE_KIND_HEAD)
|
||||
TAG_RAY_USER_NODE_TYPE, STATUS_UNINITIALIZED, STATUS_WAITING_FOR_SSH,
|
||||
STATUS_SYNCING_FILES, STATUS_SETTING_UP, STATUS_UP_TO_DATE,
|
||||
NODE_KIND_WORKER, NODE_KIND_UNMANAGED, NODE_KIND_HEAD)
|
||||
from ray.autoscaler._private.legacy_info_string import legacy_log_info_string
|
||||
from ray.autoscaler._private.providers import _get_node_provider
|
||||
from ray.autoscaler._private.updater import NodeUpdaterThread
|
||||
from ray.autoscaler._private.node_launcher import NodeLauncher
|
||||
@@ -25,8 +27,8 @@ from ray.autoscaler._private.resource_demand_scheduler import \
|
||||
get_bin_pack_residual, ResourceDemandScheduler, NodeType, NodeID, NodeIP, \
|
||||
ResourceDict
|
||||
from ray.autoscaler._private.util import ConcurrentCounter, validate_config, \
|
||||
with_head_node_ip, hash_launch_conf, hash_runtime_conf, add_prefix, \
|
||||
DEBUG_AUTOSCALING_STATUS, DEBUG_AUTOSCALING_ERROR
|
||||
with_head_node_ip, hash_launch_conf, hash_runtime_conf, \
|
||||
DEBUG_AUTOSCALING_ERROR, format_info_string
|
||||
from ray.autoscaler._private.constants import \
|
||||
AUTOSCALER_MAX_NUM_FAILURES, AUTOSCALER_MAX_LAUNCH_BATCH, \
|
||||
AUTOSCALER_MAX_CONCURRENT_LAUNCHES, AUTOSCALER_UPDATE_INTERVAL_S, \
|
||||
@@ -41,20 +43,23 @@ UpdateInstructions = namedtuple(
|
||||
"UpdateInstructions",
|
||||
["node_id", "init_commands", "start_ray_commands", "docker_config"])
|
||||
|
||||
AutoscalerSummary = namedtuple(
|
||||
"AutoscalerSummary",
|
||||
["active_nodes", "pending_nodes", "pending_launches", "failed_nodes"])
|
||||
|
||||
|
||||
class StandardAutoscaler:
|
||||
"""The autoscaling control loop for a Ray cluster.
|
||||
|
||||
There are two ways to start an autoscaling cluster: manually by running
|
||||
`ray start --head --autoscaling-config=/path/to/config.yaml` on a
|
||||
instance that has permission to launch other instances, or you can also use
|
||||
`ray up /path/to/config.yaml` from your laptop, which will
|
||||
configure the right AWS/Cloud roles automatically.
|
||||
|
||||
StandardAutoscaler's `update` method is periodically called by `monitor.py`
|
||||
to add and remove nodes as necessary. Currently, load-based autoscaling is
|
||||
not implemented, so all this class does is try to maintain a constant
|
||||
cluster size.
|
||||
`ray start --head --autoscaling-config=/path/to/config.yaml` on a instance
|
||||
that has permission to launch other instances, or you can also use `ray up
|
||||
/path/to/config.yaml` from your laptop, which will configure the right
|
||||
AWS/Cloud roles automatically. See the documentation for a full definition
|
||||
of autoscaling behavior:
|
||||
https://docs.ray.io/en/master/cluster/autoscaling.html
|
||||
StandardAutoscaler's `update` method is periodically called in
|
||||
`monitor.py`'s monitoring loop.
|
||||
|
||||
StandardAutoscaler is also used to bootstrap clusters (by adding workers
|
||||
until the cluster size that can handle the resource demand is met).
|
||||
@@ -120,9 +125,6 @@ class StandardAutoscaler:
|
||||
for local_path in self.config["file_mounts"].values():
|
||||
assert os.path.exists(local_path)
|
||||
|
||||
# List of resource bundles the user is requesting of the cluster.
|
||||
self.resource_demand_vector = []
|
||||
|
||||
logger.info("StandardAutoscaler: {}".format(self.config))
|
||||
|
||||
def update(self):
|
||||
@@ -149,7 +151,6 @@ class StandardAutoscaler:
|
||||
|
||||
def _update(self):
|
||||
now = time.time()
|
||||
|
||||
# Throttle autoscaling updates to this interval to avoid exceeding
|
||||
# rate limits on API calls.
|
||||
if now - self.last_update_time < self.update_interval_s:
|
||||
@@ -162,7 +163,6 @@ class StandardAutoscaler:
|
||||
self.provider.internal_ip(node_id)
|
||||
for node_id in self.all_workers()
|
||||
])
|
||||
self.log_info_string(nodes)
|
||||
|
||||
# Terminate any idle or out of date nodes
|
||||
last_used = self.load_metrics.last_used_time_by_ip
|
||||
@@ -176,7 +176,7 @@ class StandardAutoscaler:
|
||||
sorted_node_ids = self._sort_based_on_last_used(nodes, last_used)
|
||||
# Don't terminate nodes needed by request_resources()
|
||||
nodes_allowed_to_terminate: Dict[NodeID, bool] = {}
|
||||
if self.resource_demand_vector:
|
||||
if self.load_metrics.get_resource_requests():
|
||||
nodes_allowed_to_terminate = self._get_nodes_allowed_to_terminate(
|
||||
sorted_node_ids)
|
||||
|
||||
@@ -202,7 +202,6 @@ class StandardAutoscaler:
|
||||
if nodes_to_terminate:
|
||||
self.provider.terminate_nodes(nodes_to_terminate)
|
||||
nodes = self.workers()
|
||||
self.log_info_string(nodes)
|
||||
|
||||
# Terminate nodes if there are too many
|
||||
nodes_to_terminate = []
|
||||
@@ -217,8 +216,6 @@ class StandardAutoscaler:
|
||||
self.provider.terminate_nodes(nodes_to_terminate)
|
||||
nodes = self.workers()
|
||||
|
||||
self.log_info_string(nodes)
|
||||
|
||||
to_launch = self.resource_demand_scheduler.get_nodes_to_launch(
|
||||
self.provider.non_terminated_nodes(tag_filters={}),
|
||||
self.pending_launches.breakdown(),
|
||||
@@ -226,7 +223,7 @@ class StandardAutoscaler:
|
||||
self.load_metrics.get_resource_utilization(),
|
||||
self.load_metrics.get_pending_placement_groups(),
|
||||
self.load_metrics.get_static_node_resources_by_ip(),
|
||||
ensure_min_cluster_size=self.resource_demand_vector)
|
||||
ensure_min_cluster_size=self.load_metrics.get_resource_requests())
|
||||
for node_type, count in to_launch.items():
|
||||
self.launch_new_node(count, node_type=node_type)
|
||||
|
||||
@@ -256,7 +253,6 @@ class StandardAutoscaler:
|
||||
self.provider.terminate_nodes(nodes_to_terminate)
|
||||
|
||||
nodes = self.workers()
|
||||
self.log_info_string(nodes)
|
||||
|
||||
# Update nodes with out-of-date files.
|
||||
# TODO(edoakes): Spawning these threads directly seems to cause
|
||||
@@ -282,6 +278,9 @@ class StandardAutoscaler:
|
||||
for node_id in nodes:
|
||||
self.recover_if_needed(node_id, now)
|
||||
|
||||
logger.info(self.info_string())
|
||||
legacy_log_info_string(self, nodes)
|
||||
|
||||
def _sort_based_on_last_used(self, nodes: List[NodeID],
|
||||
last_used: Dict[str, float]) -> List[NodeID]:
|
||||
"""Sort the nodes based on the last time they were used.
|
||||
@@ -333,7 +332,7 @@ class StandardAutoscaler:
|
||||
NodeIP,
|
||||
ResourceDict] = \
|
||||
self.load_metrics.get_static_node_resources_by_ip()
|
||||
head_node_resources = static_nodes[head_ip]
|
||||
head_node_resources = static_nodes.get(head_ip, {})
|
||||
else:
|
||||
head_node_resources = {}
|
||||
|
||||
@@ -362,7 +361,7 @@ class StandardAutoscaler:
|
||||
used_resource_requests: List[ResourceDict]
|
||||
_, used_resource_requests = \
|
||||
get_bin_pack_residual(max_node_resources,
|
||||
self.resource_demand_vector)
|
||||
self.load_metrics.get_resource_requests())
|
||||
# Remove the first entry (the head node).
|
||||
max_node_resources.pop(0)
|
||||
# Remove the first entry (the head node).
|
||||
@@ -482,11 +481,13 @@ class StandardAutoscaler:
|
||||
# for legacy yamls.
|
||||
self.resource_demand_scheduler.reset_config(
|
||||
self.provider, self.available_node_types,
|
||||
self.config["max_workers"], upscaling_speed)
|
||||
self.config["max_workers"], self.config["head_node_type"],
|
||||
upscaling_speed)
|
||||
else:
|
||||
self.resource_demand_scheduler = ResourceDemandScheduler(
|
||||
self.provider, self.available_node_types,
|
||||
self.config["max_workers"], upscaling_speed)
|
||||
self.config["max_workers"], self.config["head_node_type"],
|
||||
upscaling_speed)
|
||||
|
||||
except Exception as e:
|
||||
if errors_fatal:
|
||||
@@ -532,15 +533,17 @@ class StandardAutoscaler:
|
||||
if not self.can_update(node_id):
|
||||
return
|
||||
key = self.provider.internal_ip(node_id)
|
||||
if key not in self.load_metrics.last_heartbeat_time_by_ip:
|
||||
self.load_metrics.last_heartbeat_time_by_ip[key] = now
|
||||
last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip[key]
|
||||
delta = now - last_heartbeat_time
|
||||
if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S:
|
||||
return
|
||||
|
||||
if key in self.load_metrics.last_heartbeat_time_by_ip:
|
||||
last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip[
|
||||
key]
|
||||
delta = now - last_heartbeat_time
|
||||
if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S:
|
||||
return
|
||||
|
||||
logger.warning("StandardAutoscaler: "
|
||||
"{}: No heartbeat in {}s, "
|
||||
"restarting Ray to recover...".format(node_id, delta))
|
||||
"{}: No recent heartbeat, "
|
||||
"restarting Ray to recover...".format(node_id))
|
||||
updater = NodeUpdaterThread(
|
||||
node_id=node_id,
|
||||
provider_config=self.config["provider"],
|
||||
@@ -677,43 +680,6 @@ class StandardAutoscaler:
|
||||
return self.provider.non_terminated_nodes(
|
||||
tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_UNMANAGED})
|
||||
|
||||
def log_info_string(self, nodes):
|
||||
tmp = "Cluster status: "
|
||||
tmp += self.info_string(nodes)
|
||||
tmp += "\n"
|
||||
tmp += self.load_metrics.info_string()
|
||||
tmp += "\n"
|
||||
tmp += self.resource_demand_scheduler.debug_string(
|
||||
nodes, self.pending_launches.breakdown(),
|
||||
self.load_metrics.get_resource_utilization())
|
||||
if _internal_kv_initialized():
|
||||
_internal_kv_put(DEBUG_AUTOSCALING_STATUS, tmp, overwrite=True)
|
||||
if self.prefix_cluster_info:
|
||||
tmp = add_prefix(tmp, self.config["cluster_name"])
|
||||
logger.debug(tmp)
|
||||
|
||||
def info_string(self, nodes):
|
||||
suffix = ""
|
||||
if self.updaters:
|
||||
suffix += " ({} updating)".format(len(self.updaters))
|
||||
if self.num_failed_updates:
|
||||
suffix += " ({} failed to update)".format(
|
||||
len(self.num_failed_updates))
|
||||
|
||||
return "{} nodes{}".format(len(nodes), suffix)
|
||||
|
||||
def request_resources(self, resources: List[dict]):
|
||||
"""Called by monitor to request resources.
|
||||
|
||||
Args:
|
||||
resources: A list of resource bundles.
|
||||
"""
|
||||
if resources:
|
||||
logger.info(
|
||||
"StandardAutoscaler: resource_requests={}".format(resources))
|
||||
assert isinstance(resources, list), resources
|
||||
self.resource_demand_vector = resources
|
||||
|
||||
def kill_workers(self):
|
||||
logger.error("StandardAutoscaler: kill_workers triggered")
|
||||
nodes = self.workers()
|
||||
@@ -721,3 +687,66 @@ class StandardAutoscaler:
|
||||
self.provider.terminate_nodes(nodes)
|
||||
logger.error("StandardAutoscaler: terminated {} node(s)".format(
|
||||
len(nodes)))
|
||||
|
||||
def summary(self):
    """Summarizes the active, pending, and failed node launches.

    An active node is a node whose raylet is actively reporting heartbeats.
    A pending node is a non-active node whose status tag is uninitialized,
    waiting for ssh, syncing files, or setting up.
    If a node is not pending or active, it is failed.
    Nodes tagged as unmanaged are left out of all three groups.

    Returns:
        AutoscalerSummary: The summary.
    """
    pending_states = [
        STATUS_UNINITIALIZED, STATUS_WAITING_FOR_SSH,
        STATUS_SYNCING_FILES, STATUS_SETTING_UP
    ]

    active_nodes = Counter()
    pending_nodes = []
    failed_nodes = []

    for node_id in self.provider.non_terminated_nodes(tag_filters={}):
        ip = self.provider.internal_ip(node_id)
        node_tags = self.provider.node_tags(node_id)
        if node_tags[TAG_RAY_NODE_KIND] == NODE_KIND_UNMANAGED:
            continue
        node_type = node_tags[TAG_RAY_USER_NODE_TYPE]

        # TODO (Alex): If a node's raylet has died, it shouldn't be marked
        # as active.
        if self.load_metrics.is_active(ip):
            active_nodes[node_type] += 1
            continue

        # The status tag is only consulted for nodes that are not active.
        if node_tags[TAG_RAY_NODE_STATUS] in pending_states:
            pending_nodes.append((ip, node_type))
        else:
            # TODO (Alex): Failed nodes are now immediately killed, so
            # this list will almost always be empty. We should ideally
            # keep a cache of recently failed nodes and their startup
            # logs.
            failed_nodes.append((ip, node_type))

    # The concurrent counter leaves some 0 counts in, so we need to
    # manually filter those out.
    pending_launches = {
        node_type: count
        for node_type, count in self.pending_launches.breakdown().items()
        if count
    }

    return AutoscalerSummary(
        active_nodes=active_nodes,
        pending_nodes=pending_nodes,
        pending_launches=pending_launches,
        failed_nodes=failed_nodes)
|
||||
|
||||
def info_string(self):
    """Return the formatted cluster status, prefixed with a newline.

    The text comes from ``format_info_string`` applied to the load-metrics
    summary and this autoscaler's own ``summary()``.
    """
    return "\n" + format_info_string(self.load_metrics.summary(),
                                     self.summary())
|
||||
|
||||
@@ -35,6 +35,8 @@ logger = logging.getLogger(__name__)
|
||||
HASH_MAX_LENGTH = 10
|
||||
KUBECTL_RSYNC = os.path.join(
|
||||
os.path.dirname(os.path.abspath(__file__)), "kubernetes/kubectl-rsync.sh")
|
||||
MAX_HOME_RETRIES = 3
|
||||
HOME_RETRY_DELAY_S = 5
|
||||
|
||||
_config = {"use_login_shells": True, "silent_rsync": True}
|
||||
|
||||
@@ -248,16 +250,31 @@ class KubernetesCommandRunner(CommandRunnerInterface):
|
||||
|
||||
@property
|
||||
def _home(self):
|
||||
if self._home_cached is not None:
|
||||
return self._home_cached
|
||||
for _ in range(MAX_HOME_RETRIES - 1):
|
||||
try:
|
||||
self._home_cached = self._try_to_get_home()
|
||||
return self._home_cached
|
||||
except Exception:
|
||||
# TODO (Dmitri): Identify the exception we're trying to avoid.
|
||||
logger.info("Error reading container's home directory. "
|
||||
f"Retrying in {HOME_RETRY_DELAY_S} seconds.")
|
||||
time.sleep(HOME_RETRY_DELAY_S)
|
||||
# Last try
|
||||
self._home_cached = self._try_to_get_home()
|
||||
return self._home_cached
|
||||
|
||||
def _try_to_get_home(self):
|
||||
# TODO (Dmitri): Think about how to use the node's HOME variable
|
||||
# without making an extra kubectl exec call.
|
||||
if self._home_cached is None:
|
||||
cmd = self.kubectl + [
|
||||
"exec", "-it", self.node_id, "--", "printenv", "HOME"
|
||||
]
|
||||
joined_cmd = " ".join(cmd)
|
||||
raw_out = self.process_runner.check_output(joined_cmd, shell=True)
|
||||
self._home_cached = raw_out.decode().strip("\n\r")
|
||||
return self._home_cached
|
||||
cmd = self.kubectl + [
|
||||
"exec", "-it", self.node_id, "--", "printenv", "HOME"
|
||||
]
|
||||
joined_cmd = " ".join(cmd)
|
||||
raw_out = self.process_runner.check_output(joined_cmd, shell=True)
|
||||
home = raw_out.decode().strip("\n\r")
|
||||
return home
|
||||
|
||||
|
||||
class SSHOptions:
|
||||
|
||||
@@ -43,6 +43,10 @@ from ray.worker import global_worker # type: ignore
|
||||
from ray.util.debug import log_once
|
||||
|
||||
import ray.autoscaler._private.subprocess_output_util as cmd_output_util
|
||||
from ray.autoscaler._private.load_metrics import LoadMetricsSummary
|
||||
from ray.autoscaler._private.autoscaler import AutoscalerSummary
|
||||
from ray.autoscaler._private.util import format_info_string, \
|
||||
format_info_string_no_node_types
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -94,6 +98,14 @@ def debug_status() -> str:
|
||||
status = "No cluster status."
|
||||
else:
|
||||
status = status.decode("utf-8")
|
||||
as_dict = json.loads(status)
|
||||
lm_summary = LoadMetricsSummary(**as_dict["load_metrics_report"])
|
||||
if "autoscaler_report" in as_dict:
|
||||
autoscaler_summary = AutoscalerSummary(
|
||||
**as_dict["autoscaler_report"])
|
||||
status = format_info_string(lm_summary, autoscaler_summary)
|
||||
else:
|
||||
status = format_info_string_no_node_types(lm_summary)
|
||||
if error:
|
||||
status += "\n"
|
||||
status += error.decode("utf-8")
|
||||
@@ -280,9 +292,10 @@ def _bootstrap_config(config: Dict[str, Any],
|
||||
f"Failed to autodetect node resources: {str(exc)}. "
|
||||
"You can see full stack trace with higher verbosity.")
|
||||
|
||||
# NOTE: if `resources` field is missing, validate_config for non-AWS will
|
||||
# fail (the schema error will ask the user to manually fill the resources)
|
||||
# as we currently support autofilling resources for AWS instances only.
|
||||
# NOTE: if `resources` field is missing, validate_config for providers
|
||||
# other than AWS and Kubernetes will fail (the schema error will ask the
|
||||
# user to manually fill the resources) as we currently support autofilling
|
||||
# resources for AWS and Kubernetes only.
|
||||
validate_config(config)
|
||||
resolved_config = provider_cls.bootstrap_config(config)
|
||||
|
||||
|
||||
@@ -60,6 +60,13 @@ def bootstrap_kubernetes(config):
|
||||
|
||||
|
||||
def fillout_resources_kubernetes(config):
|
||||
"""Fills CPU and GPU resources by reading pod spec of each available node
|
||||
type.
|
||||
|
||||
For each node type and each of CPU/GPU, looks at container's resources
|
||||
and limits, takes min of the two. The result is rounded up, as Ray does
|
||||
not currently support fractional CPU.
|
||||
"""
|
||||
if "available_node_types" not in config:
|
||||
return config["available_node_types"]
|
||||
node_types = copy.deepcopy(config["available_node_types"])
|
||||
@@ -96,20 +103,47 @@ def get_resource(container_resources, resource_name):
|
||||
limit = _get_resource(
|
||||
container_resources, resource_name, field_name="limits")
|
||||
resource = min(request, limit)
|
||||
# float("inf") value means the resource wasn't detected in either
|
||||
# requests or limits
|
||||
return 0 if resource == float("inf") else int(resource)
|
||||
|
||||
|
||||
def _get_resource(container_resources, resource_name, field_name):
|
||||
if (field_name in container_resources
|
||||
and resource_name in container_resources[field_name]):
|
||||
return _parse_resource(container_resources[field_name][resource_name])
|
||||
else:
|
||||
"""Returns the resource quantity.
|
||||
|
||||
The amount of resource is rounded up to nearest integer.
|
||||
Returns float("inf") if the resource is not present.
|
||||
|
||||
Args:
|
||||
container_resources (dict): Container's resource field.
|
||||
resource_name (str): One of 'cpu' or 'gpu'.
|
||||
field_name (str): One of 'requests' or 'limits'.
|
||||
|
||||
Returns:
|
||||
Union[int, float]: Detected resource quantity.
|
||||
"""
|
||||
if field_name not in container_resources:
|
||||
# No limit/resource field.
|
||||
return float("inf")
|
||||
resources = container_resources[field_name]
|
||||
# Look for keys containing the resource_name. For example,
|
||||
# the key 'nvidia.com/gpu' contains the key 'gpu'.
|
||||
matching_keys = [key for key in resources if resource_name in key.lower()]
|
||||
if len(matching_keys) == 0:
|
||||
return float("inf")
|
||||
if len(matching_keys) > 1:
|
||||
# Should have only one match -- mostly relevant for gpu.
|
||||
raise ValueError(f"Multiple {resource_name} types not supported.")
|
||||
# E.g. 'nvidia.com/gpu' or 'cpu'.
|
||||
resource_key = matching_keys.pop()
|
||||
resource_quantity = resources[resource_key]
|
||||
return _parse_resource(resource_quantity)
|
||||
|
||||
|
||||
def _parse_resource(resource):
|
||||
resource_str = str(resource)
|
||||
if resource_str[-1] == "m":
|
||||
# For example, '500m' rounds up to 1.
|
||||
return math.ceil(int(resource_str[:-1]) / 1000)
|
||||
else:
|
||||
return int(resource_str)
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
import logging
|
||||
from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS_LEGACY
|
||||
from ray.experimental.internal_kv import _internal_kv_put, \
|
||||
_internal_kv_initialized
|
||||
"""This file provides legacy support for the old info string in order to
|
||||
ensure the dashboard's `api/cluster_status` does not break backwards
|
||||
compatibility.
|
||||
"""
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def legacy_log_info_string(autoscaler, nodes):
    """Build the legacy cluster status string, store it and log it.

    The string is made of three newline-separated parts: the node count
    line from ``info_string``, the load metrics info string, and the
    resource demand scheduler's debug string. When the internal KV store is
    initialized the result is written under
    ``DEBUG_AUTOSCALING_STATUS_LEGACY``; it is always logged at debug level.

    Args:
        autoscaler: Object exposing ``load_metrics``,
            ``resource_demand_scheduler`` and ``pending_launches``.
        nodes: The nodes to report on.
    """
    sections = [
        "Cluster status: " + info_string(autoscaler, nodes),
        autoscaler.load_metrics.info_string(),
        autoscaler.resource_demand_scheduler.debug_string(
            nodes, autoscaler.pending_launches.breakdown(),
            autoscaler.load_metrics.get_resource_utilization()),
    ]
    status = "\n".join(sections)
    if _internal_kv_initialized():
        _internal_kv_put(
            DEBUG_AUTOSCALING_STATUS_LEGACY, status, overwrite=True)
    logger.debug(status)
|
||||
|
||||
|
||||
def info_string(autoscaler, nodes):
    """Return a one-line node count summary.

    Args:
        autoscaler: Object exposing ``updaters`` and ``num_failed_updates``;
            both are tested for truthiness and measured with ``len``.
        nodes: Sized collection of nodes.

    Returns:
        str: ``"<n> nodes"`` followed by an optional ``" (<k> updating)"``
        and an optional ``" (<k> failed to update)"`` suffix.
    """
    annotations = []
    if autoscaler.updaters:
        annotations.append(
            " ({} updating)".format(len(autoscaler.updaters)))
    if autoscaler.num_failed_updates:
        annotations.append(" ({} failed to update)".format(
            len(autoscaler.num_failed_updates)))
    return "{} nodes{}".format(len(nodes), "".join(annotations))
|
||||
@@ -1,16 +1,26 @@
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
import logging
|
||||
import time
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import ray._private.services as services
|
||||
from ray.autoscaler._private.constants import MEMORY_RESOURCE_UNIT_BYTES
|
||||
from ray.autoscaler._private.constants import MEMORY_RESOURCE_UNIT_BYTES,\
|
||||
AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE
|
||||
from ray.autoscaler._private.util import add_resources, freq_of_dicts
|
||||
from ray.gcs_utils import PlacementGroupTableData
|
||||
from ray.autoscaler._private.resource_demand_scheduler import \
|
||||
NodeIP, ResourceDict
|
||||
from ray.core.generated.common_pb2 import PlacementStrategy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
LoadMetricsSummary = namedtuple("LoadMetricsSummary", [
|
||||
"head_ip", "usage", "resource_demand", "pg_demand", "request_demand",
|
||||
"node_types"
|
||||
])
|
||||
|
||||
|
||||
class LoadMetrics:
|
||||
"""Container for cluster load metrics.
|
||||
@@ -31,6 +41,7 @@ class LoadMetrics:
|
||||
self.waiting_bundles = []
|
||||
self.infeasible_bundles = []
|
||||
self.pending_placement_groups = []
|
||||
self.resource_requests = []
|
||||
|
||||
def update(self,
|
||||
ip: str,
|
||||
@@ -72,34 +83,37 @@ class LoadMetrics:
|
||||
|
||||
def mark_active(self, ip):
|
||||
assert ip is not None, "IP should be known at this time"
|
||||
logger.info("Node {} is newly setup, treating as active".format(ip))
|
||||
logger.debug("Node {} is newly setup, treating as active".format(ip))
|
||||
self.last_heartbeat_time_by_ip[ip] = time.time()
|
||||
|
||||
def is_active(self, ip):
|
||||
return ip in self.last_heartbeat_time_by_ip
|
||||
|
||||
def prune_active_ips(self, active_ips):
|
||||
active_ips = set(active_ips)
|
||||
active_ips.add(self.local_ip)
|
||||
|
||||
def prune(mapping):
|
||||
def prune(mapping, should_log):
|
||||
unwanted = set(mapping) - active_ips
|
||||
for unwanted_key in unwanted:
|
||||
# TODO (Alex): Change this back to info after #12138.
|
||||
logger.debug("LoadMetrics: "
|
||||
"Removed mapping: {} - {}".format(
|
||||
unwanted_key, mapping[unwanted_key]))
|
||||
if should_log:
|
||||
logger.info("LoadMetrics: "
|
||||
"Removed mapping: {} - {}".format(
|
||||
unwanted_key, mapping[unwanted_key]))
|
||||
del mapping[unwanted_key]
|
||||
if unwanted:
|
||||
if unwanted and should_log:
|
||||
# TODO (Alex): Change this back to info after #12138.
|
||||
logger.debug(
|
||||
logger.info(
|
||||
"LoadMetrics: "
|
||||
"Removed {} stale ip mappings: {} not in {}".format(
|
||||
len(unwanted), unwanted, active_ips))
|
||||
assert not (unwanted & set(mapping))
|
||||
|
||||
prune(self.last_used_time_by_ip)
|
||||
prune(self.static_resources_by_ip)
|
||||
prune(self.dynamic_resources_by_ip)
|
||||
prune(self.resource_load_by_ip)
|
||||
prune(self.last_heartbeat_time_by_ip)
|
||||
prune(self.last_used_time_by_ip, should_log=True)
|
||||
prune(self.static_resources_by_ip, should_log=False)
|
||||
prune(self.dynamic_resources_by_ip, should_log=False)
|
||||
prune(self.resource_load_by_ip, should_log=False)
|
||||
prune(self.last_heartbeat_time_by_ip, should_log=False)
|
||||
|
||||
def get_node_resources(self):
|
||||
"""Return a list of node resources (static resource sizes).
|
||||
@@ -155,12 +169,82 @@ class LoadMetrics:
|
||||
|
||||
return resources_used, resources_total
|
||||
|
||||
def get_resource_demand_vector(self):
|
||||
return self.waiting_bundles + self.infeasible_bundles
|
||||
def get_resource_demand_vector(self, clip=True):
|
||||
if clip:
|
||||
# Bound the total number of bundles to
|
||||
# 2xMAX_RESOURCE_DEMAND_VECTOR_SIZE. This guarantees the resource
|
||||
# demand scheduler bin packing algorithm takes a reasonable amount
|
||||
# of time to run.
|
||||
return (
|
||||
self.
|
||||
waiting_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE] +
|
||||
self.
|
||||
infeasible_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE]
|
||||
)
|
||||
else:
|
||||
return self.waiting_bundles + self.infeasible_bundles
|
||||
|
||||
def get_resource_requests(self):
|
||||
return self.resource_requests
|
||||
|
||||
def get_pending_placement_groups(self):
|
||||
return self.pending_placement_groups
|
||||
|
||||
def summary(self):
|
||||
available_resources = reduce(add_resources,
|
||||
self.dynamic_resources_by_ip.values()
|
||||
) if self.dynamic_resources_by_ip else {}
|
||||
total_resources = reduce(add_resources,
|
||||
self.static_resources_by_ip.values()
|
||||
) if self.static_resources_by_ip else {}
|
||||
usage_dict = {}
|
||||
for key in total_resources:
|
||||
total = total_resources[key]
|
||||
usage_dict[key] = (total - available_resources[key], total)
|
||||
|
||||
summarized_demand_vector = freq_of_dicts(
|
||||
self.get_resource_demand_vector(clip=False))
|
||||
summarized_resource_requests = freq_of_dicts(
|
||||
self.get_resource_requests())
|
||||
|
||||
def placement_group_serializer(pg):
|
||||
bundles = tuple(
|
||||
frozenset(bundle.unit_resources.items())
|
||||
for bundle in pg.bundles)
|
||||
return (bundles, pg.strategy)
|
||||
|
||||
def placement_group_deserializer(pg_tuple):
|
||||
# We marshal this as a dictionary so that we can easily json.dumps
|
||||
# it later.
|
||||
# TODO (Alex): Would there be a benefit to properly
|
||||
# marshalling this (into a protobuf)?
|
||||
bundles = list(map(dict, pg_tuple[0]))
|
||||
return {
|
||||
"bundles": freq_of_dicts(bundles),
|
||||
"strategy": PlacementStrategy.Name(pg_tuple[1])
|
||||
}
|
||||
|
||||
summarized_placement_groups = freq_of_dicts(
|
||||
self.get_pending_placement_groups(),
|
||||
serializer=placement_group_serializer,
|
||||
deserializer=placement_group_deserializer)
|
||||
nodes_summary = freq_of_dicts(self.static_resources_by_ip.values())
|
||||
|
||||
return LoadMetricsSummary(
|
||||
head_ip=self.local_ip,
|
||||
usage=usage_dict,
|
||||
resource_demand=summarized_demand_vector,
|
||||
pg_demand=summarized_placement_groups,
|
||||
request_demand=summarized_resource_requests,
|
||||
node_types=nodes_summary)
|
||||
|
||||
def set_resource_requests(self, requested_resources):
|
||||
if requested_resources is not None:
|
||||
assert isinstance(requested_resources, list), requested_resources
|
||||
self.resource_requests = [
|
||||
request for request in requested_resources if len(request) > 0
|
||||
]
|
||||
|
||||
def info_string(self):
|
||||
return " - " + "\n - ".join(
|
||||
["{}: {}".format(k, v) for k, v in sorted(self._info().items())])
|
||||
|
||||
@@ -47,16 +47,19 @@ class ResourceDemandScheduler:
|
||||
provider: NodeProvider,
|
||||
node_types: Dict[NodeType, NodeTypeConfigDict],
|
||||
max_workers: int,
|
||||
head_node_type: NodeType,
|
||||
upscaling_speed: float = 1) -> None:
|
||||
self.provider = provider
|
||||
self.node_types = copy.deepcopy(node_types)
|
||||
self.max_workers = max_workers
|
||||
self.head_node_type = head_node_type
|
||||
self.upscaling_speed = upscaling_speed
|
||||
|
||||
def reset_config(self,
|
||||
provider: NodeProvider,
|
||||
node_types: Dict[NodeType, NodeTypeConfigDict],
|
||||
max_workers: int,
|
||||
head_node_type: NodeType,
|
||||
upscaling_speed: float = 1) -> None:
|
||||
"""Updates the class state variables.
|
||||
|
||||
@@ -89,6 +92,7 @@ class ResourceDemandScheduler:
|
||||
self.provider = provider
|
||||
self.node_types = copy.deepcopy(final_node_types)
|
||||
self.max_workers = max_workers
|
||||
self.head_node_type = head_node_type
|
||||
self.upscaling_speed = upscaling_speed
|
||||
|
||||
def is_legacy_yaml(self,
|
||||
@@ -145,18 +149,18 @@ class ResourceDemandScheduler:
|
||||
node_resources, node_type_counts = self.calculate_node_resources(
|
||||
nodes, launching_nodes, unused_resources_by_ip)
|
||||
|
||||
logger.info("Cluster resources: {}".format(node_resources))
|
||||
logger.info("Node counts: {}".format(node_type_counts))
|
||||
logger.debug("Cluster resources: {}".format(node_resources))
|
||||
logger.debug("Node counts: {}".format(node_type_counts))
|
||||
# Step 2: add nodes to add to satisfy min_workers for each type
|
||||
(node_resources,
|
||||
node_type_counts,
|
||||
adjusted_min_workers) = \
|
||||
_add_min_workers_nodes(
|
||||
node_resources, node_type_counts, self.node_types,
|
||||
self.max_workers, ensure_min_cluster_size)
|
||||
self.max_workers, self.head_node_type, ensure_min_cluster_size)
|
||||
|
||||
# Step 3: add nodes for strict spread groups
|
||||
logger.info(f"Placement group demands: {pending_placement_groups}")
|
||||
logger.debug(f"Placement group demands: {pending_placement_groups}")
|
||||
placement_group_demand_vector, strict_spreads = \
|
||||
placement_groups_to_resource_demands(pending_placement_groups)
|
||||
resource_demands.extend(placement_group_demand_vector)
|
||||
@@ -183,12 +187,13 @@ class ResourceDemandScheduler:
|
||||
# groups
|
||||
unfulfilled, _ = get_bin_pack_residual(node_resources,
|
||||
resource_demands)
|
||||
logger.info("Resource demands: {}".format(resource_demands))
|
||||
logger.info("Unfulfilled demands: {}".format(unfulfilled))
|
||||
logger.debug("Resource demands: {}".format(resource_demands))
|
||||
logger.debug("Unfulfilled demands: {}".format(unfulfilled))
|
||||
# Add 1 to account for the head node.
|
||||
max_to_add = self.max_workers + 1 - sum(node_type_counts.values())
|
||||
nodes_to_add_based_on_demand = get_nodes_for(
|
||||
self.node_types, node_type_counts, max_to_add, unfulfilled)
|
||||
self.node_types, node_type_counts, self.head_node_type, max_to_add,
|
||||
unfulfilled)
|
||||
# Merge nodes to add based on demand and nodes to add based on
|
||||
# min_workers constraint. We add them because nodes to add based on
|
||||
# demand was calculated after the min_workers constraint was respected.
|
||||
@@ -206,7 +211,7 @@ class ResourceDemandScheduler:
|
||||
total_nodes_to_add, unused_resources_by_ip.keys(), nodes,
|
||||
launching_nodes, adjusted_min_workers)
|
||||
|
||||
logger.info("Node requests: {}".format(total_nodes_to_add))
|
||||
logger.debug("Node requests: {}".format(total_nodes_to_add))
|
||||
return total_nodes_to_add
|
||||
|
||||
def _legacy_worker_node_to_launch(
|
||||
@@ -443,6 +448,7 @@ class ResourceDemandScheduler:
|
||||
to_launch = get_nodes_for(
|
||||
self.node_types,
|
||||
node_type_counts,
|
||||
self.head_node_type,
|
||||
max_to_add,
|
||||
unfulfilled,
|
||||
strict_spread=True)
|
||||
@@ -490,7 +496,7 @@ def _add_min_workers_nodes(
|
||||
node_resources: List[ResourceDict],
|
||||
node_type_counts: Dict[NodeType, int],
|
||||
node_types: Dict[NodeType, NodeTypeConfigDict], max_workers: int,
|
||||
ensure_min_cluster_size: List[ResourceDict]
|
||||
head_node_type: NodeType, ensure_min_cluster_size: List[ResourceDict]
|
||||
) -> (List[ResourceDict], Dict[NodeType, int], Dict[NodeType, int]):
|
||||
"""Updates resource demands to respect the min_workers and
|
||||
request_resources() constraints.
|
||||
@@ -515,6 +521,9 @@ def _add_min_workers_nodes(
|
||||
existing = node_type_counts.get(node_type, 0)
|
||||
target = min(
|
||||
config.get("min_workers", 0), config.get("max_workers", 0))
|
||||
if node_type == head_node_type:
|
||||
# Add 1 to account for head node.
|
||||
target = target + 1
|
||||
if existing < target:
|
||||
total_nodes_to_add_dict[node_type] = target - existing
|
||||
node_type_counts[node_type] = target
|
||||
@@ -537,7 +546,7 @@ def _add_min_workers_nodes(
|
||||
max_node_resources, ensure_min_cluster_size)
|
||||
# Get the nodes to meet the unfulfilled.
|
||||
nodes_to_add_request_resources = get_nodes_for(
|
||||
node_types, node_type_counts, max_to_add,
|
||||
node_types, node_type_counts, head_node_type, max_to_add,
|
||||
resource_requests_unfulfilled)
|
||||
# Update the resources, counts and total nodes to add.
|
||||
for node_type in nodes_to_add_request_resources:
|
||||
@@ -558,6 +567,7 @@ def _add_min_workers_nodes(
|
||||
|
||||
def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
|
||||
existing_nodes: Dict[NodeType, int],
|
||||
head_node_type: NodeType,
|
||||
max_to_add: int,
|
||||
resources: List[ResourceDict],
|
||||
strict_spread: bool = False) -> Dict[NodeType, int]:
|
||||
@@ -581,9 +591,13 @@ def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
|
||||
while resources and sum(nodes_to_add.values()) < max_to_add:
|
||||
utilization_scores = []
|
||||
for node_type in node_types:
|
||||
max_workers_of_node_type = node_types[node_type].get(
|
||||
"max_workers", 0)
|
||||
if head_node_type == node_type:
|
||||
# Add 1 to account for head node.
|
||||
max_workers_of_node_type = max_workers_of_node_type + 1
|
||||
if (existing_nodes.get(node_type, 0) + nodes_to_add.get(
|
||||
node_type, 0) >= node_types[node_type].get(
|
||||
"max_workers", 0)):
|
||||
node_type, 0) >= max_workers_of_node_type):
|
||||
continue
|
||||
node_resources = node_types[node_type]["resources"]
|
||||
if strict_spread:
|
||||
@@ -601,8 +615,14 @@ def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
|
||||
# starts up because placement groups are scheduled via custom
|
||||
# resources. This will behave properly with the current utilization
|
||||
# score heuristic, but it's a little dangerous and misleading.
|
||||
logger.info(
|
||||
"No feasible node type to add for {}".format(resources))
|
||||
logger.warning(
|
||||
f"The autoscaler could not find a node type to satisfy the"
|
||||
f"request: {resources}. If this request is related to "
|
||||
f"placement groups the resource request will resolve itself, "
|
||||
f"otherwise please specify a node type with the necessary "
|
||||
f"resource "
|
||||
f"https://docs.ray.io/en/master/cluster/autoscaling.html#multiple-node-type-autoscaling." # noqa: E501
|
||||
)
|
||||
break
|
||||
|
||||
utilization_scores = sorted(utilization_scores, reverse=True)
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
import collections
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import hashlib
|
||||
import json
|
||||
import jsonschema
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import ray
|
||||
import ray.ray_constants
|
||||
import ray._private.services as services
|
||||
from ray.autoscaler._private.providers import _get_default_config
|
||||
from ray.autoscaler._private.docker import validate_docker_config
|
||||
@@ -20,6 +22,7 @@ RAY_SCHEMA_PATH = os.path.join(
|
||||
# Internal kv keys for storing debug status.
|
||||
DEBUG_AUTOSCALING_ERROR = "__autoscaling_error"
|
||||
DEBUG_AUTOSCALING_STATUS = "__autoscaling_status"
|
||||
DEBUG_AUTOSCALING_STATUS_LEGACY = "__autoscaling_status_legacy"
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -246,6 +249,47 @@ def hash_runtime_conf(file_mounts,
|
||||
return (_hash_cache[conf_str], file_mounts_contents_hash)
|
||||
|
||||
|
||||
def add_resources(dict1: Dict[str, float],
                  dict2: Dict[str, float]) -> Dict[str, float]:
    """Sum two resource dictionaries key by key.

    Keys present in only one of the inputs keep their value.

    Returns:
        dict: A freshly created dictionary; neither input is modified.
    """
    summed = dict1.copy()
    for key in dict2:
        summed[key] = dict2[key] + summed.get(key, 0)
    return summed
|
||||
|
||||
|
||||
def freq_of_dicts(dicts: List[Dict],
                  serializer=lambda d: frozenset(d.items()),
                  deserializer=dict):
    """Count how often each dictionary (or other unhashable item) occurs.

    Items are first converted to a hashable form so they can be counted,
    then converted back for the result.

    Args:
        dicts (List[D]): The items to be counted.
        serializer (D -> S): Converts an item into a hashable value of
            type S. The default turns a dictionary into a frozenset of its
            key/value pairs.
        deserializer (S -> U): Converts a hashable value of type S back
            into the form placed in the result. For dictionaries U := D.

    Returns:
        List[Tuple[U, int]]: One tuple per distinct item, holding the
        deserialized item and the number of times it occurred.
    """
    counts = collections.Counter(serializer(item) for item in dicts)
    return [(deserializer(key), count) for key, count in counts.items()]
|
||||
|
||||
|
||||
def add_prefix(info_string, prefix):
|
||||
"""Prefixes each line of info_string, except the first, by prefix."""
|
||||
lines = info_string.split("\n")
|
||||
@@ -255,3 +299,132 @@ def add_prefix(info_string, prefix):
|
||||
prefixed_lines.append(prefixed_line)
|
||||
prefixed_info_string = "\n".join(prefixed_lines)
|
||||
return prefixed_info_string
|
||||
|
||||
|
||||
def format_pg(pg):
    """Format a summarized placement group as a single line.

    Args:
        pg (dict): Has a "strategy" entry and a "bundles" entry; the latter
            is an iterable of ``(bundle, count)`` pairs.

    Returns:
        str: Comma-separated ``"<bundle> * <count>"`` items followed by the
        strategy in parentheses.
    """
    strategy = pg["strategy"]
    bundles = pg["bundles"]
    bundles_str = ", ".join(
        f"{bundle} * {count}" for bundle, count in bundles)
    return f"{bundles_str} ({strategy})"
|
||||
|
||||
|
||||
def get_usage_report(lm_summary):
    """Format per-resource usage as one line per resource.

    "memory" and "object_store_memory" amounts are multiplied by
    ``MEMORY_RESOURCE_UNIT_BYTES / 2**30`` and shown with a GiB label; all
    other resources are shown as ``used/total`` unchanged.

    Args:
        lm_summary: Object whose ``usage`` maps a resource name to a
            ``(used, total)`` pair.

    Returns:
        str: The usage lines joined by newlines (empty when there is no
        usage to report).
    """
    report_lines = []
    for resource, (used, total) in lm_summary.usage.items():
        if resource in ["memory", "object_store_memory"]:
            to_GiB = ray.ray_constants.MEMORY_RESOURCE_UNIT_BYTES / 2**30
            used *= to_GiB
            total *= to_GiB
            report_lines.append(f" {used:.2f}/{total:.3f} GiB {resource}")
        else:
            report_lines.append(f" {used}/{total} {resource}")
    return "\n".join(report_lines)
|
||||
|
||||
|
||||
def get_demand_report(lm_summary):
    """Format pending resource demands as one line per distinct demand.

    Lines are emitted in this order: task/actor demands, placement group
    demands, then ``request_resources()`` demands.

    Args:
        lm_summary: Object exposing ``resource_demand``, ``pg_demand`` and
            ``request_demand``, each an iterable of ``(item, count)`` pairs.

    Returns:
        str: The demand lines joined by newlines, or a placeholder line
        when there are no demands at all.
    """
    report_lines = [
        f" {bundle}: {count}+ pending tasks/actors"
        for bundle, count in lm_summary.resource_demand
    ]
    for pg, count in lm_summary.pg_demand:
        pg_str = format_pg(pg)
        report_lines.append(f" {pg_str}: {count}+ pending placement groups")
    report_lines.extend(
        f" {bundle}: {count}+ from request_resources()"
        for bundle, count in lm_summary.request_demand)
    if not report_lines:
        return " (no resource demands)"
    return "\n".join(report_lines)
|
||||
|
||||
|
||||
def format_info_string(lm_summary, autoscaler_summary, time=None):
    """Render the autoscaler status report as a multi-line string.

    The report has a "Node status" section (healthy nodes, pending nodes,
    recent failures) and a "Resources" section (usage and demands).

    Args:
        lm_summary: Load metrics summary, passed on to
            ``get_usage_report`` and ``get_demand_report``.
        autoscaler_summary: Object exposing ``active_nodes`` and
            ``pending_launches`` (mappings from node type to count), and
            ``pending_nodes`` and ``failed_nodes`` (iterables of
            ``(ip, node_type)`` pairs).
        time: Timestamp shown in the header. Defaults to
            ``datetime.now()``.

    Returns:
        str: The formatted report.
    """
    if time is None:
        time = datetime.now()
    header = "=" * 8 + f" Autoscaler status: {time} " + "=" * 8
    separator = "-" * len(header)

    available_node_report_lines = []
    for node_type, count in autoscaler_summary.active_nodes.items():
        line = f" {count} {node_type}"
        available_node_report_lines.append(line)
    available_node_report = "\n".join(available_node_report_lines)

    pending_lines = []
    for node_type, count in autoscaler_summary.pending_launches.items():
        line = f" {node_type}, {count} launching"
        pending_lines.append(line)
    for ip, node_type in autoscaler_summary.pending_nodes:
        line = f" {ip}: {node_type}, setting up"
        pending_lines.append(line)
    if pending_lines:
        pending_report = "\n".join(pending_lines)
    else:
        pending_report = " (no pending nodes)"

    failure_lines = []
    for ip, node_type in autoscaler_summary.failed_nodes:
        line = f" {ip}: {node_type}"
        # Collect the line; without this the report would always claim
        # there were no failures.
        failure_lines.append(line)
    failure_report = "Recent failures:\n"
    if failure_lines:
        failure_report += "\n".join(failure_lines)
    else:
        failure_report += " (no failures)"

    usage_report = get_usage_report(lm_summary)
    demand_report = get_demand_report(lm_summary)

    formatted_output = f"""{header}
Node status
{separator}
Healthy:
{available_node_report}
Pending:
{pending_report}
{failure_report}

Resources
{separator}

Usage:
{usage_report}

Demands:
{demand_report}"""
    return formatted_output
|
||||
|
||||
|
||||
def format_info_string_no_node_types(lm_summary, time=None):
    """Render the cluster status report for clusters without node types.

    Nodes are grouped by their resources instead of by a node type name.

    Args:
        lm_summary: Load metrics summary; ``node_types`` is read as an
            iterable of ``(resources, count)`` pairs, and the summary is
            passed on to ``get_usage_report`` and ``get_demand_report``.
        time: Timestamp shown in the header. Defaults to
            ``datetime.now()``.

    Returns:
        str: The formatted report.
    """
    if time is None:
        time = datetime.now()
    header = "=" * 8 + f" Cluster status: {time} " + "=" * 8
    separator = "-" * len(header)

    node_report = "\n".join(
        f" {count} node(s) with resources: {node_type}"
        for node_type, count in lm_summary.node_types)

    usage_report = get_usage_report(lm_summary)
    demand_report = get_demand_report(lm_summary)

    return f"""{header}
Node status
{separator}
{node_report}

Resources
{separator}
Usage:
{usage_report}

Demands:
{demand_report}"""
|
||||
|
||||
@@ -250,9 +250,11 @@ worker_nodes:
|
||||
# Files or directories to copy to the head and worker nodes. The format is a
|
||||
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
|
||||
file_mounts: {
|
||||
# "/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
# "~/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "~/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
}
|
||||
# Note that the container images in this example have a non-root user.
|
||||
# To avoid permissions issues, we recommend mounting into a subdirectory of home (~).
|
||||
|
||||
# Files or directories to copy from the head node to the worker nodes. The format is a
|
||||
# list of paths. The same path on the head node will be copied to the worker node.
|
||||
|
||||
@@ -250,9 +250,11 @@ worker_nodes:
|
||||
# Files or directories to copy to the head and worker nodes. The format is a
|
||||
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
|
||||
file_mounts: {
|
||||
# "/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
# "~/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "~/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
}
|
||||
# Note that the container images in this example have a non-root user.
|
||||
# To avoid permissions issues, we recommend mounting into a subdirectory of home (~).
|
||||
|
||||
# Files or directories to copy from the head node to the worker nodes. The format is a
|
||||
# list of paths. The same path on the head node will be copied to the worker node.
|
||||
|
||||
@@ -286,9 +286,11 @@ worker_nodes:
|
||||
# Files or directories to copy to the head and worker nodes. The format is a
|
||||
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
|
||||
file_mounts: {
|
||||
# "/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
# "~/path1/on/remote/machine": "/path1/on/local/machine",
|
||||
# "~/path2/on/remote/machine": "/path2/on/local/machine",
|
||||
}
|
||||
# Note that the container images in this example have a non-root user.
|
||||
# To avoid permissions issues, we recommend mounting into a subdirectory of home (~).
|
||||
|
||||
# List of commands that will be run before `setup_commands`. If docker is
|
||||
# enabled, these commands will run outside the container and before docker
|
||||
|
||||
@@ -142,7 +142,8 @@ class WorkerCrashedError(RayError):
|
||||
"""Indicates that the worker died unexpectedly while executing a task."""
|
||||
|
||||
def __str__(self):
|
||||
return "The worker died unexpectedly while executing this task."
|
||||
return ("The worker died unexpectedly while executing this task. "
|
||||
"Check python-core-worker-*.log files for more information.")
|
||||
|
||||
|
||||
class RayActorError(RayError):
|
||||
@@ -153,7 +154,8 @@ class RayActorError(RayError):
|
||||
"""
|
||||
|
||||
def __str__(self):
|
||||
return "The actor died unexpectedly before finishing this task."
|
||||
return ("The actor died unexpectedly before finishing this task. "
|
||||
"Check python-core-worker-*.log files for more information.")
|
||||
|
||||
|
||||
class RaySystemError(RayError):
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from .dynamic_resources import set_resource
|
||||
from .object_spilling import force_spill_objects
|
||||
__all__ = [
|
||||
"set_resource",
|
||||
"force_spill_objects",
|
||||
]
|
||||
|
||||
@@ -1,117 +1,100 @@
|
||||
from ray.experimental.client.api import ClientAPI
|
||||
from ray.experimental.client.api import APIImpl
|
||||
from typing import Optional, List, Tuple
|
||||
from contextlib import contextmanager
|
||||
from typing import List, Tuple
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# About these global variables: Ray 1.0 uses exported module functions to
|
||||
# provide its API, and we need to match that. However, we want different
|
||||
# behaviors depending on where, exactly, in the client stack this is running.
|
||||
#
|
||||
# The reason for these differences depends on what's being pickled and passed
|
||||
# to functions, or functions inside functions. So there are three cases to care
|
||||
# about
|
||||
#
|
||||
# (Python Client)-->(Python ClientServer)-->(Internal Raylet Process)
|
||||
#
|
||||
# * _client_api should be set if we're inside the client
|
||||
# * _server_api should be set if we're inside the clientserver
|
||||
# * Both will be set if we're running both (as in a test)
|
||||
# * Neither should be set if we're inside the raylet (but we still need to shim
|
||||
# from the client API surface to the Ray API)
|
||||
#
|
||||
# The job of RayAPIStub (below) delegates to the appropriate one of these
|
||||
# depending on what's set or not. Then, all users importing the ray object
|
||||
# from this package get the stub which routes them to the appropriate APIImpl.
|
||||
_client_api: Optional[APIImpl] = None
|
||||
_server_api: Optional[APIImpl] = None
|
||||
|
||||
# The reason for _is_server is a hack around the above comment while running
|
||||
# tests. If we have both a client and a server trying to control these static
|
||||
# variables then we need a way to decide which to use. In this case, both
|
||||
# _client_api and _server_api are set.
|
||||
# This boolean flips between the two
|
||||
_is_server: bool = False
|
||||
|
||||
|
||||
@contextmanager
|
||||
def stash_api_for_tests(in_test: bool):
|
||||
global _is_server
|
||||
is_server = _is_server
|
||||
if in_test:
|
||||
_is_server = True
|
||||
yield _server_api
|
||||
if in_test:
|
||||
_is_server = is_server
|
||||
|
||||
|
||||
def _set_client_api(val: Optional[APIImpl]):
|
||||
global _client_api
|
||||
global _is_server
|
||||
if _client_api is not None:
|
||||
raise Exception("Trying to set more than one client API")
|
||||
_client_api = val
|
||||
_is_server = False
|
||||
|
||||
|
||||
def _set_server_api(val: Optional[APIImpl]):
|
||||
global _server_api
|
||||
global _is_server
|
||||
if _server_api is not None:
|
||||
raise Exception("Trying to set more than one server API")
|
||||
_server_api = val
|
||||
_is_server = True
|
||||
|
||||
|
||||
def reset_api():
|
||||
global _client_api
|
||||
global _server_api
|
||||
global _is_server
|
||||
_client_api = None
|
||||
_server_api = None
|
||||
_is_server = False
|
||||
|
||||
|
||||
def _get_client_api() -> APIImpl:
|
||||
global _client_api
|
||||
global _server_api
|
||||
global _is_server
|
||||
api = None
|
||||
if _is_server:
|
||||
api = _server_api
|
||||
else:
|
||||
api = _client_api
|
||||
if api is None:
|
||||
# We're inside a raylet worker
|
||||
from ray.experimental.client.server.core_ray_api import CoreRayAPI
|
||||
return CoreRayAPI()
|
||||
return api
|
||||
|
||||
|
||||
class RayAPIStub:
|
||||
"""This class stands in as the replacement API for the `import ray` module.
|
||||
|
||||
Much like the ray module, this mostly delegates the work to the
|
||||
_client_worker. As parts of the ray API are covered, they are piped through
|
||||
here or on the client worker API.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
from ray.experimental.client.api import ClientAPI
|
||||
self.api = ClientAPI()
|
||||
self.client_worker = None
|
||||
self._server = None
|
||||
self._connected_with_init = False
|
||||
self._inside_client_test = False
|
||||
|
||||
def connect(self,
|
||||
conn_str: str,
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
stub=None):
|
||||
metadata: List[Tuple[str, str]] = None) -> None:
|
||||
"""Connect the Ray Client to a server.
|
||||
|
||||
Args:
|
||||
conn_str: Connection string, in the form "[host]:port"
|
||||
secure: Whether to use a TLS secured gRPC channel
|
||||
metadata: gRPC metadata to send on connect
|
||||
"""
|
||||
# Delay imports until connect to avoid circular imports.
|
||||
from ray.experimental.client.worker import Worker
|
||||
_client_worker = Worker(
|
||||
conn_str, secure=secure, metadata=metadata, stub=stub)
|
||||
_set_client_api(ClientAPI(_client_worker))
|
||||
import ray._private.client_mode_hook
|
||||
if self.client_worker is not None:
|
||||
if self._connected_with_init:
|
||||
return
|
||||
raise Exception(
|
||||
"ray.connect() called, but ray client is already connected")
|
||||
if not self._inside_client_test:
|
||||
# If we're calling a client connect specifically and we're not
|
||||
# currently in client mode, ensure we are.
|
||||
ray._private.client_mode_hook._explicitly_enable_client_mode()
|
||||
self.client_worker = Worker(conn_str, secure=secure, metadata=metadata)
|
||||
self.api.worker = self.client_worker
|
||||
|
||||
def disconnect(self):
|
||||
global _client_api
|
||||
if _client_api is not None:
|
||||
_client_api.close()
|
||||
_client_api = None
|
||||
"""Disconnect the Ray Client.
|
||||
"""
|
||||
if self.client_worker is not None:
|
||||
self.client_worker.close()
|
||||
self.client_worker = None
|
||||
|
||||
# remote can be called outside of a connection, which is why it
|
||||
# exists on the same API layer as connect() itself.
|
||||
def remote(self, *args, **kwargs):
|
||||
"""remote is the hook stub passed on to replace `ray.remote`.
|
||||
|
||||
This sets up remote functions or actors, as the decorator,
|
||||
but does not execute them.
|
||||
|
||||
Args:
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
return self.api.remote(*args, **kwargs)
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
global _get_client_api
|
||||
api = _get_client_api()
|
||||
return getattr(api, key)
|
||||
if not self.is_connected():
|
||||
raise Exception("Ray Client is not connected. "
|
||||
"Please connect by calling `ray.connect`.")
|
||||
return getattr(self.api, key)
|
||||
|
||||
def is_connected(self) -> bool:
    """Report whether this stub currently holds a client worker.

    ``self.api`` is created once in ``__init__`` and never reset, so
    testing it against ``None`` is always true. The connection state is
    tracked by ``self.client_worker``, which ``connect()`` assigns and
    ``disconnect()`` sets back to ``None``.

    Returns:
        True if a client worker is attached, False otherwise.
    """
    return self.client_worker is not None
|
||||
|
||||
def init(self, *args, **kwargs):
|
||||
if self._server is not None:
|
||||
raise Exception("Trying to start two instances of ray via client")
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
self._server, address_info = ray_client_server.init_and_serve(
|
||||
"localhost:50051", *args, **kwargs)
|
||||
self.connect("localhost:50051")
|
||||
self._connected_with_init = True
|
||||
return address_info
|
||||
|
||||
def shutdown(self, _exiting_interpreter=False):
|
||||
self.disconnect()
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
if self._server is None:
|
||||
return
|
||||
ray_client_server.shutdown_with_server(self._server,
|
||||
_exiting_interpreter)
|
||||
self._server = None
|
||||
|
||||
|
||||
ray = RayAPIStub()
|
||||
|
||||
@@ -1,74 +1,51 @@
|
||||
# This file defines an interface and client-side API stub
|
||||
# for referring either to the core Ray API or the same interface
|
||||
# from the Ray client.
|
||||
#
|
||||
# In tandem with __init__.py, we want to expose an API that's
|
||||
# close to `python/ray/__init__.py` but with more than one implementation.
|
||||
# The stubs in __init__ should call into a well-defined interface.
|
||||
# Only the core Ray API implementation should actually `import ray`
|
||||
# (and thus import all the raylet worker C bindings and such).
|
||||
# But to make sure that we're matching these calls, we define this API.
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Union, Optional
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
"""This file defines the interface between the ray client worker
|
||||
and the overall ray module API.
|
||||
"""
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientStub
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray._raylet import ObjectRef
|
||||
|
||||
# Use the imports for type checking. This is a python 3.6 limitation.
|
||||
# See https://www.python.org/dev/peps/pep-0563/
|
||||
PutType = Union[ClientObjectRef, ObjectRef]
|
||||
|
||||
|
||||
class APIImpl(ABC):
|
||||
"""
|
||||
APIImpl is the interface to implement for whichever version of the core
|
||||
Ray API that needs abstracting when run in client mode.
|
||||
class ClientAPI:
|
||||
"""The Client-side methods corresponding to the ray API. Delegates
|
||||
to the Client Worker that contains the connection to the ClientServer.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
||||
"""
|
||||
get is the hook stub passed on to replace `ray.get`
|
||||
def __init__(self, worker=None):
|
||||
self.worker = worker
|
||||
|
||||
def get(self, vals, *, timeout=None):
|
||||
"""get is the hook stub passed on to replace `ray.get`
|
||||
|
||||
Args:
|
||||
vals: [Client]ObjectRef or list of these refs to retrieve.
|
||||
timeout: Optional timeout in milliseconds
|
||||
"""
|
||||
pass
|
||||
return self.worker.get(vals, timeout=timeout)
|
||||
|
||||
@abstractmethod
|
||||
def put(self, vals: Any, *args,
|
||||
**kwargs) -> Union["ClientObjectRef", "ObjectRef"]:
|
||||
"""
|
||||
put is the hook stub passed on to replace `ray.put`
|
||||
def put(self, *args, **kwargs):
|
||||
"""put is the hook stub passed on to replace `ray.put`
|
||||
|
||||
Args:
|
||||
vals: The value or list of values to `put`.
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
return self.worker.put(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def wait(self, *args, **kwargs):
|
||||
"""
|
||||
wait is the hook stub passed on to replace `ray.wait`
|
||||
"""wait is the hook stub passed on to replace `ray.wait`
|
||||
|
||||
Args:
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
return self.worker.wait(*args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def remote(self, *args, **kwargs):
|
||||
"""
|
||||
remote is the hook stub passed on to replace `ray.remote`.
|
||||
"""remote is the hook stub passed on to replace `ray.remote`.
|
||||
|
||||
This sets up remote functions or actors, as the decorator,
|
||||
but does not execute them.
|
||||
@@ -77,12 +54,24 @@ class APIImpl(ABC):
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
# Delayed import to avoid a cyclic import
|
||||
from ray.experimental.client.common import remote_decorator
|
||||
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
|
||||
# This is the case where the decorator is just @ray.remote.
|
||||
return remote_decorator(options=None)(args[0])
|
||||
error_string = ("The @ray.remote decorator must be applied either "
|
||||
"with no arguments and no parentheses, for example "
|
||||
"'@ray.remote', or it must be applied using some of "
|
||||
"the arguments 'num_returns', 'num_cpus', 'num_gpus', "
|
||||
"'memory', 'object_store_memory', 'resources', "
|
||||
"'max_calls', or 'max_restarts', like "
|
||||
"'@ray.remote(num_returns=2, "
|
||||
"resources={\"CustomResource\": 1})'.")
|
||||
assert len(args) == 0 and len(kwargs) > 0, error_string
|
||||
return remote_decorator(options=kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def call_remote(self, instance: "ClientStub", *args, **kwargs):
|
||||
"""
|
||||
call_remote is called by stub objects to execute them remotely.
|
||||
"""call_remote is called by stub objects to execute them remotely.
|
||||
|
||||
This is used by stub objects in situations where they're called
|
||||
with .remote, eg, `f.remote()` or `actor_cls.remote()`.
|
||||
@@ -95,31 +84,57 @@ class APIImpl(ABC):
|
||||
args: opaque arguments
|
||||
kwargs: opaque keyword arguments
|
||||
"""
|
||||
pass
|
||||
return self.worker.call_remote(instance, *args, **kwargs)
|
||||
|
||||
@abstractmethod
|
||||
def close(self) -> None:
|
||||
def call_release(self, id: bytes) -> None:
|
||||
"""Attempts to release an object reference.
|
||||
|
||||
When client references are destructed, they release their reference,
|
||||
which can opportunistically send a notification through the datachannel
|
||||
to release the reference being held for that object on the server.
|
||||
|
||||
Args:
|
||||
id: The id of the reference to release on the server side.
|
||||
"""
|
||||
close cleans up an API connection by closing any channels or
|
||||
return self.worker.call_release(id)
|
||||
|
||||
def call_retain(self, id: bytes) -> None:
|
||||
"""Attempts to retain a client object reference.
|
||||
|
||||
Increments the reference count on the client side, to prevent
|
||||
the client worker from attempting to release the server reference.
|
||||
|
||||
Args:
|
||||
id: The id of the reference to retain on the client side.
|
||||
"""
|
||||
return self.worker.call_retain(id)
|
||||
|
||||
def close(self) -> None:
|
||||
"""close cleans up an API connection by closing any channels or
|
||||
shutting down any servers gracefully.
|
||||
"""
|
||||
pass
|
||||
return self.worker.close()
|
||||
|
||||
@abstractmethod
|
||||
def kill(self, actor, *, no_restart=True):
|
||||
def get_actor(self, name: str) -> "ClientActorHandle":
|
||||
"""Returns a handle to an actor by name.
|
||||
|
||||
Args:
|
||||
name: The name passed to this actor by
|
||||
Actor.options(name="name").remote()
|
||||
"""
|
||||
kill forcibly stops an actor running in the cluster
|
||||
return self.worker.get_actor(name)
|
||||
|
||||
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
|
||||
"""kill forcibly stops an actor running in the cluster
|
||||
|
||||
Args:
|
||||
no_restart: Whether this actor should be restarted if it's a
|
||||
restartable actor.
|
||||
"""
|
||||
pass
|
||||
return self.worker.terminate_actor(actor, no_restart)
|
||||
|
||||
@abstractmethod
|
||||
def cancel(self, obj, *, force=False, recursive=True):
|
||||
"""
|
||||
Cancels a task on the cluster.
|
||||
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
|
||||
"""Cancels a task on the cluster.
|
||||
|
||||
If the specified task is pending execution, it will not be executed. If
|
||||
the task is currently executing, the behavior depends on the ``force``
|
||||
@@ -136,46 +151,11 @@ class APIImpl(ABC):
|
||||
recursive (boolean): Whether to try to cancel tasks submitted by
|
||||
the task specified.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class ClientAPI(APIImpl):
|
||||
"""
|
||||
The Client-side methods corresponding to the ray API. Delegates
|
||||
to the Client Worker that contains the connection to the ClientServer.
|
||||
"""
|
||||
|
||||
def __init__(self, worker):
|
||||
self.worker = worker
|
||||
|
||||
def get(self, vals, *, timeout=None):
|
||||
return self.worker.get(vals, timeout=timeout)
|
||||
|
||||
def put(self, *args, **kwargs):
|
||||
return self.worker.put(*args, **kwargs)
|
||||
|
||||
def wait(self, *args, **kwargs):
|
||||
return self.worker.wait(*args, **kwargs)
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return self.worker.remote(*args, **kwargs)
|
||||
|
||||
def call_remote(self, instance: "ClientStub", *args, **kwargs):
|
||||
return self.worker.call_remote(instance, *args, **kwargs)
|
||||
|
||||
def close(self) -> None:
|
||||
return self.worker.close()
|
||||
|
||||
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
|
||||
return self.worker.terminate_actor(actor, no_restart)
|
||||
|
||||
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
|
||||
return self.worker.terminate_task(obj, force, recursive)
|
||||
|
||||
# Various metadata methods for the client that are defined in the protocol.
|
||||
def is_initialized(self) -> bool:
|
||||
""" True if our client is connected, and if the server is initialized.
|
||||
|
||||
"""True if our client is connected, and if the server is initialized.
|
||||
Returns:
|
||||
A boolean determining if the client is connected and
|
||||
server initialized.
|
||||
@@ -188,6 +168,8 @@ class ClientAPI(APIImpl):
|
||||
Returns:
|
||||
Information about the Ray clients in the cluster.
|
||||
"""
|
||||
# This should be imported here, otherwise, it will error doc build.
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
return self.worker.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.NODES)
|
||||
|
||||
@@ -201,6 +183,8 @@ class ClientAPI(APIImpl):
|
||||
A dictionary mapping resource name to the total quantity of that
|
||||
resource in the cluster.
|
||||
"""
|
||||
# This should be imported here, otherwise, it will error doc build.
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
return self.worker.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES)
|
||||
|
||||
@@ -216,6 +200,8 @@ class ClientAPI(APIImpl):
|
||||
A dictionary mapping resource name to the total quantity of that
|
||||
resource in the cluster.
|
||||
"""
|
||||
# This should be imported here, otherwise, it will error doc build.
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
return self.worker.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES)
|
||||
|
||||
|
||||
@@ -0,0 +1,182 @@
|
||||
"""Implements the client side of the client/server pickling protocol.
|
||||
|
||||
All ray client client/server data transfer happens through this pickling
|
||||
protocol. The model is as follows:
|
||||
|
||||
* All Client objects (eg ClientObjectRef) always live on the client and
|
||||
are never represented in the server
|
||||
* All Ray objects (eg, ray.ObjectRef) always live on the server and are
|
||||
never returned to the client
|
||||
* In order to translate between these two references, PickleStub tuples
|
||||
are generated as persistent ids in the data blobs during the pickling
|
||||
and unpickling of these objects.
|
||||
|
||||
The PickleStubs have just enough information to find or generate their
|
||||
associated partner object on either side.
|
||||
|
||||
This also has the advantage of avoiding predefined pickle behavior for ray
|
||||
objects, which may include ray internal reference counting.
|
||||
|
||||
ClientPickler dumps things from the client into the appropriate stubs
|
||||
ServerUnpickler loads stubs from the server into their client counterparts.
|
||||
"""
|
||||
|
||||
import cloudpickle
|
||||
import io
|
||||
import sys
|
||||
|
||||
from typing import NamedTuple
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
from ray.experimental.client import RayAPIStub
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientActorRef
|
||||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.common import ClientRemoteMethod
|
||||
from ray.experimental.client.common import OptionWrapper
|
||||
from ray.experimental.client.common import SelfReferenceSentinel
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
from ray._private.client_mode_hook import disable_client_hook
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
try:
|
||||
import pickle5 as pickle # noqa: F401
|
||||
except ImportError:
|
||||
import pickle # noqa: F401
|
||||
else:
|
||||
import pickle # noqa: F401
|
||||
|
||||
# NOTE(barakmich): These PickleStubs are really close to
|
||||
# the data for an execution, with no arguments. Combine the two?
|
||||
PickleStub = NamedTuple("PickleStub",
|
||||
[("type", str), ("client_id", str), ("ref_id", bytes),
|
||||
("name", Optional[str]),
|
||||
("baseline_options", Optional[Dict])])
|
||||
|
||||
|
||||
class ClientPickler(cloudpickle.CloudPickler):
|
||||
def __init__(self, client_id, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.client_id = client_id
|
||||
|
||||
def persistent_id(self, obj):
|
||||
if isinstance(obj, RayAPIStub):
|
||||
return PickleStub(
|
||||
type="Ray",
|
||||
client_id=self.client_id,
|
||||
ref_id=b"",
|
||||
name=None,
|
||||
baseline_options=None,
|
||||
)
|
||||
elif isinstance(obj, ClientObjectRef):
|
||||
return PickleStub(
|
||||
type="Object",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj.id,
|
||||
name=None,
|
||||
baseline_options=None,
|
||||
)
|
||||
elif isinstance(obj, ClientActorHandle):
|
||||
return PickleStub(
|
||||
type="Actor",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj._actor_id,
|
||||
name=None,
|
||||
baseline_options=None,
|
||||
)
|
||||
elif isinstance(obj, ClientRemoteFunc):
|
||||
# TODO(barakmich): This is going to have trouble with mutually
|
||||
# recursive functions that haven't, as yet, been executed. It's
|
||||
# relatively doable (keep track of intermediate refs in progress
|
||||
# with ensure_ref and return appropriately) But punting for now.
|
||||
if obj._ref is None:
|
||||
obj._ensure_ref()
|
||||
if type(obj._ref) == SelfReferenceSentinel:
|
||||
return PickleStub(
|
||||
type="RemoteFuncSelfReference",
|
||||
client_id=self.client_id,
|
||||
ref_id=b"",
|
||||
name=None,
|
||||
baseline_options=None,
|
||||
)
|
||||
return PickleStub(
|
||||
type="RemoteFunc",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj._ref.id,
|
||||
name=None,
|
||||
baseline_options=obj._options,
|
||||
)
|
||||
elif isinstance(obj, ClientActorClass):
|
||||
# TODO(barakmich): Mutual recursion, as above.
|
||||
if obj._ref is None:
|
||||
obj._ensure_ref()
|
||||
if type(obj._ref) == SelfReferenceSentinel:
|
||||
return PickleStub(
|
||||
type="RemoteActorSelfReference",
|
||||
client_id=self.client_id,
|
||||
ref_id=b"",
|
||||
name=None,
|
||||
baseline_options=None,
|
||||
)
|
||||
return PickleStub(
|
||||
type="RemoteActor",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj._ref.id,
|
||||
name=None,
|
||||
baseline_options=obj._options,
|
||||
)
|
||||
elif isinstance(obj, ClientRemoteMethod):
|
||||
return PickleStub(
|
||||
type="RemoteMethod",
|
||||
client_id=self.client_id,
|
||||
ref_id=obj.actor_handle.actor_ref.id,
|
||||
name=obj.method_name,
|
||||
baseline_options=None,
|
||||
)
|
||||
elif isinstance(obj, OptionWrapper):
|
||||
raise NotImplementedError(
|
||||
"Sending a partial option is unimplemented")
|
||||
return None
|
||||
|
||||
|
||||
class ServerUnpickler(pickle.Unpickler):
|
||||
def persistent_load(self, pid):
|
||||
assert isinstance(pid, PickleStub)
|
||||
if pid.type == "Object":
|
||||
return ClientObjectRef(id=pid.ref_id)
|
||||
elif pid.type == "Actor":
|
||||
return ClientActorHandle(ClientActorRef(id=pid.ref_id))
|
||||
else:
|
||||
raise NotImplementedError("Being passed back an unknown stub")
|
||||
|
||||
|
||||
def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes:
|
||||
with disable_client_hook():
|
||||
with io.BytesIO() as file:
|
||||
cp = ClientPickler(client_id, file, protocol=protocol)
|
||||
cp.dump(obj)
|
||||
return file.getvalue()
|
||||
|
||||
|
||||
def loads_from_server(data: bytes,
|
||||
*,
|
||||
fix_imports=True,
|
||||
encoding="ASCII",
|
||||
errors="strict") -> Any:
|
||||
if isinstance(data, str):
|
||||
raise TypeError("Can't load pickle from unicode string")
|
||||
file = io.BytesIO(data)
|
||||
return ServerUnpickler(
|
||||
file, fix_imports=fix_imports, encoding=encoding,
|
||||
errors=errors).load()
|
||||
|
||||
|
||||
def convert_to_arg(val: Any, client_id: str) -> ray_client_pb2.Arg:
|
||||
out = ray_client_pb2.Arg()
|
||||
out.local = ray_client_pb2.Arg.Locality.INTERNED
|
||||
out.data = dumps_from_client(val, client_id)
|
||||
return out
|
||||
@@ -1,16 +1,31 @@
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
from ray.experimental.client import ray
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from ray import cloudpickle
|
||||
from ray.experimental.client.options import validate_options
|
||||
|
||||
import base64
|
||||
import inspect
|
||||
from ray.util.inspect import is_cython
|
||||
import json
|
||||
import threading
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
|
||||
class ClientBaseRef:
|
||||
def __init__(self, id, handle=None):
|
||||
self.id = id
|
||||
self.handle = handle
|
||||
def __init__(self, id: bytes):
|
||||
self.id = None
|
||||
if not isinstance(id, bytes):
|
||||
raise TypeError("ClientRefs must be created with bytes IDs")
|
||||
self.id: bytes = id
|
||||
ray.call_retain(id)
|
||||
|
||||
def binary(self):
|
||||
return self.id
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.id == other.id
|
||||
|
||||
def __repr__(self):
|
||||
return "%s(%s)" % (
|
||||
@@ -18,20 +33,16 @@ class ClientBaseRef:
|
||||
self.id.hex(),
|
||||
)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.id == other.id
|
||||
def __hash__(self):
|
||||
return hash(self.id)
|
||||
|
||||
def binary(self):
|
||||
return self.id
|
||||
|
||||
@classmethod
|
||||
def from_remote_ref(cls, ref: ray_client_pb2.RemoteRef):
|
||||
return cls(id=ref.id, handle=ref.handle)
|
||||
def __del__(self):
|
||||
if ray.is_connected() and self.id is not None:
|
||||
ray.call_release(self.id)
|
||||
|
||||
|
||||
class ClientObjectRef(ClientBaseRef):
|
||||
def _unpack_ref(self):
|
||||
return cloudpickle.loads(self.handle)
|
||||
pass
|
||||
|
||||
|
||||
class ClientActorRef(ClientBaseRef):
|
||||
@@ -43,8 +54,7 @@ class ClientStub:
|
||||
|
||||
|
||||
class ClientRemoteFunc(ClientStub):
|
||||
"""
|
||||
A stub created on the Ray Client to represent a remote
|
||||
"""A stub created on the Ray Client to represent a remote
|
||||
function that can be executed on the cluster.
|
||||
|
||||
This class is allowed to be passed around between remote functions.
|
||||
@@ -53,55 +63,57 @@ class ClientRemoteFunc(ClientStub):
|
||||
_func: The actual function to execute remotely
|
||||
_name: The original name of the function
|
||||
_ref: The ClientObjectRef of the pickled code of the function, _func
|
||||
_raylet_remote: The Raylet-side ray.remote_function.RemoteFunction
|
||||
for this object
|
||||
"""
|
||||
|
||||
def __init__(self, f):
|
||||
def __init__(self, f, options=None):
|
||||
self._lock = threading.Lock()
|
||||
self._func = f
|
||||
self._name = f.__name__
|
||||
self.id = None
|
||||
|
||||
# self._ref can be lazily instantiated. Rather than eagerly creating
|
||||
# function data objects in the server we can put them just before we
|
||||
# execute the function, especially in cases where many @ray.remote
|
||||
# functions exist in a library and only a handful are ever executed by
|
||||
# a user of the library.
|
||||
#
|
||||
# TODO(barakmich): This ref might actually be better as a serialized
|
||||
# ObjectRef. This requires being able to serialize the ref without
|
||||
# pinning it (as the lifetime of the ref is tied with the server, not
|
||||
# the client)
|
||||
self._ref = None
|
||||
self._raylet_remote = None
|
||||
self._options = validate_options(options)
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Reject direct invocation of the remote-function stub.

    Args:
        args: ignored positional arguments
        kwargs: ignored keyword arguments

    Raises:
        TypeError: always; the message names the function's ``.remote``
            method as the way to call it.
    """
    # Each adjacent literal needs its own ``f`` prefix; without it the
    # placeholder is emitted verbatim instead of the function name.
    raise TypeError("Remote function cannot be called directly. "
                    f"Use {self._name}.remote method instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.call_remote(self, *args, **kwargs)
|
||||
return return_refs(ray.call_remote(self, *args, **kwargs))
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
if self._raylet_remote is None:
|
||||
self._raylet_remote = ray.remote(self._func)
|
||||
return self._raylet_remote
|
||||
def options(self, **kwargs):
|
||||
return OptionWrapper(self, kwargs)
|
||||
|
||||
def _remote(self, args=[], kwargs={}, **option_args):
|
||||
return self.options(**option_args).remote(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
|
||||
|
||||
def _ensure_ref(self):
|
||||
with self._lock:
|
||||
if self._ref is None:
|
||||
# While calling ray.put() on our function, if
|
||||
# our function is recursive, it will attempt to
|
||||
# encode the ClientRemoteFunc -- itself -- and
|
||||
# infinitely recurse on _ensure_ref.
|
||||
#
|
||||
# So we set the state of the reference to be an
|
||||
# in-progress self reference value, which
|
||||
# the encoding can detect and handle correctly.
|
||||
self._ref = SelfReferenceSentinel()
|
||||
self._ref = ray.put(self._func)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
if self._ref is None:
|
||||
self._ref = ray.put(self._func)
|
||||
self._ensure_ref()
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.FUNCTION
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.handle
|
||||
task.payload_id = self._ref.id
|
||||
set_task_options(task, self._options, "baseline_options")
|
||||
return task
|
||||
|
||||
|
||||
class ClientActorClass(ClientStub):
|
||||
""" A stub created on the Ray Client to represent an actor class.
|
||||
"""A stub created on the Ray Client to represent an actor class.
|
||||
|
||||
It is wrapped by ray.remote and can be executed on the cluster.
|
||||
|
||||
@@ -109,39 +121,40 @@ class ClientActorClass(ClientStub):
|
||||
actor_cls: The actual class to execute remotely
|
||||
_name: The original name of the class
|
||||
_ref: The ClientObjectRef of the pickled `actor_cls`
|
||||
_raylet_remote: The Raylet-side ray.ActorClass for this object
|
||||
"""
|
||||
|
||||
def __init__(self, actor_cls):
|
||||
def __init__(self, actor_cls, options=None):
|
||||
self.actor_cls = actor_cls
|
||||
self._name = actor_cls.__name__
|
||||
self._ref = None
|
||||
self._raylet_remote = None
|
||||
self._options = validate_options(options)
|
||||
|
||||
def __call__(self, *args, **kwargs):
    """Reject direct instantiation of the actor-class stub.

    Args:
        args: ignored positional arguments
        kwargs: ignored keyword arguments

    Raises:
        TypeError: always; the message names the class's ``.remote()``
            method as the way to create the actor.
    """
    # Each adjacent literal needs its own ``f`` prefix; without it the
    # placeholder is emitted verbatim instead of the class name.
    raise TypeError("Remote actor cannot be instantiated directly. "
                    f"Use {self._name}.remote() instead")
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_cls": self.actor_cls,
|
||||
"_name": self._name,
|
||||
"_ref": self._ref,
|
||||
}
|
||||
return state
|
||||
def _ensure_ref(self):
|
||||
if self._ref is None:
|
||||
# As before, set the state of the reference to be an
|
||||
# in-progress self reference value, which
|
||||
# the encoding can detect and handle correctly.
|
||||
self._ref = SelfReferenceSentinel()
|
||||
self._ref = ray.put(self.actor_cls)
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_cls = state["actor_cls"]
|
||||
self._name = state["_name"]
|
||||
self._ref = state["_ref"]
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
def remote(self, *args, **kwargs) -> "ClientActorHandle":
|
||||
# Actually instantiate the actor
|
||||
ref = ray.call_remote(self, *args, **kwargs)
|
||||
return ClientActorHandle(ClientActorRef(ref.id, ref.handle), self)
|
||||
ref_ids = ray.call_remote(self, *args, **kwargs)
|
||||
assert len(ref_ids) == 1
|
||||
return ClientActorHandle(ClientActorRef(ref_ids[0]))
|
||||
|
||||
def options(self, **kwargs):
|
||||
return ActorOptionWrapper(self, kwargs)
|
||||
|
||||
def _remote(self, args=[], kwargs={}, **option_args):
|
||||
return self.options(**option_args).remote(*args, **kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
|
||||
return "ClientActorClass(%s, %s)" % (self._name, self._ref)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key not in self.__dict__:
|
||||
@@ -149,12 +162,12 @@ class ClientActorClass(ClientStub):
|
||||
raise NotImplementedError("static methods")
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
if self._ref is None:
|
||||
self._ref = ray.put(self.actor_cls)
|
||||
self._ensure_ref()
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.ACTOR
|
||||
task.name = self._name
|
||||
task.payload_id = self._ref.handle
|
||||
task.payload_id = self._ref.id
|
||||
set_task_options(task, self._options, "baseline_options")
|
||||
return task
|
||||
|
||||
|
||||
@@ -174,29 +187,12 @@ class ClientActorHandle(ClientStub):
|
||||
ray.actor.ActorHandle contained in the actor_id ref.
|
||||
"""
|
||||
|
||||
def __init__(self, actor_ref: ClientActorRef,
|
||||
actor_class: ClientActorClass):
|
||||
def __init__(self, actor_ref: ClientActorRef):
|
||||
self.actor_ref = actor_ref
|
||||
self.actor_class = actor_class
|
||||
self._real_actor_handle = None
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
if self._real_actor_handle is None:
|
||||
self._real_actor_handle = cloudpickle.loads(self.actor_ref.handle)
|
||||
return self._real_actor_handle
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_ref": self.actor_ref,
|
||||
"actor_class": self.actor_class,
|
||||
"_real_actor_handle": self._real_actor_handle,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_ref = state["actor_ref"]
|
||||
self.actor_class = state["actor_class"]
|
||||
self._real_actor_handle = state["_real_actor_handle"]
|
||||
def __del__(self) -> None:
|
||||
if ray.is_connected():
|
||||
ray.call_release(self.actor_ref.id)
|
||||
|
||||
@property
|
||||
def _actor_id(self):
|
||||
@@ -226,65 +222,90 @@ class ClientRemoteMethod(ClientStub):
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise TypeError(f"Remote method cannot be called directly. "
|
||||
"Use {self._name}.remote() instead")
|
||||
|
||||
def _get_ray_remote_impl(self):
|
||||
return getattr(self.actor_handle._get_ray_remote_impl(),
|
||||
self.method_name)
|
||||
|
||||
def __getstate__(self) -> Dict:
|
||||
state = {
|
||||
"actor_handle": self.actor_handle,
|
||||
"method_name": self.method_name,
|
||||
}
|
||||
return state
|
||||
|
||||
def __setstate__(self, state: Dict) -> None:
|
||||
self.actor_handle = state["actor_handle"]
|
||||
self.method_name = state["method_name"]
|
||||
f"Use {self._name}.remote() instead")
|
||||
|
||||
def remote(self, *args, **kwargs):
|
||||
return ray.call_remote(self, *args, **kwargs)
|
||||
return return_refs(ray.call_remote(self, *args, **kwargs))
|
||||
|
||||
def __repr__(self):
|
||||
name = "%s.%s" % (self.actor_handle.actor_class._name,
|
||||
self.method_name)
|
||||
return "ClientRemoteMethod(%s, %s)" % (name,
|
||||
self.actor_handle.actor_id)
|
||||
return "ClientRemoteMethod(%s, %s)" % (self.method_name,
|
||||
self.actor_handle)
|
||||
|
||||
def options(self, **kwargs):
|
||||
return OptionWrapper(self, kwargs)
|
||||
|
||||
def _remote(self, args=[], kwargs={}, **option_args):
|
||||
return self.options(**option_args).remote(*args, **kwargs)
|
||||
|
||||
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
|
||||
task = ray_client_pb2.ClientTask()
|
||||
task.type = ray_client_pb2.ClientTask.METHOD
|
||||
task.name = self.method_name
|
||||
task.payload_id = self.actor_handle.actor_ref.handle
|
||||
task.payload_id = self.actor_handle.actor_ref.id
|
||||
return task
|
||||
|
||||
|
||||
def convert_from_arg(pb) -> Any:
|
||||
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
|
||||
return ClientObjectRef(pb.reference_id)
|
||||
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
|
||||
return cloudpickle.loads(pb.data)
|
||||
class OptionWrapper:
|
||||
def __init__(self, stub: ClientStub, options: Optional[Dict[str, Any]]):
|
||||
self.remote_stub = stub
|
||||
self.options = validate_options(options)
|
||||
|
||||
raise Exception("convert_from_arg: Uncovered locality enum")
|
||||
def remote(self, *args, **kwargs):
|
||||
return return_refs(ray.call_remote(self, *args, **kwargs))
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self.remote_stub, key)
|
||||
|
||||
def _prepare_client_task(self):
|
||||
task = self.remote_stub._prepare_client_task()
|
||||
set_task_options(task, self.options)
|
||||
return task
|
||||
|
||||
|
||||
def convert_to_arg(val):
    """Encode *val* as a ``ray_client_pb2.Arg`` message.

    A ``ClientObjectRef`` is sent by reference (only its ``id`` is stored);
    any other value is cloudpickled into the message's ``data`` field.
    """
    arg = ray_client_pb2.Arg()
    if not isinstance(val, ClientObjectRef):
        # Plain value: ship the pickled bytes inline.
        arg.local = ray_client_pb2.Arg.Locality.INTERNED
        arg.data = cloudpickle.dumps(val)
        return arg
    arg.local = ray_client_pb2.Arg.Locality.REFERENCE
    arg.reference_id = val.id
    return arg
|
||||
class ActorOptionWrapper(OptionWrapper):
    """``OptionWrapper`` variant whose ``remote()`` returns an actor handle."""

    def remote(self, *args, **kwargs):
        """Schedule the wrapped stub and wrap the single returned id.

        Returns:
            A ``ClientActorHandle`` built around a ``ClientActorRef`` for the
            one id that ``ray.call_remote`` returned.
        """
        returned_ids = ray.call_remote(self, *args, **kwargs)
        # Exactly one id is expected back from the call.
        assert len(returned_ids) == 1
        actor_ref = ClientActorRef(returned_ids[0])
        return ClientActorHandle(actor_ref)
|
||||
|
||||
|
||||
def encode_exception(exception) -> str:
    """Cloudpickle *exception* and return the bytes as standard base64 text."""
    pickled = cloudpickle.dumps(exception)
    encoded = base64.standard_b64encode(pickled)
    return encoded.decode()
|
||||
def set_task_options(task: ray_client_pb2.ClientTask,
                     options: Optional[Dict[str, Any]],
                     field: str = "options") -> None:
    """Store *options* as JSON on ``task.<field>``, or clear that field.

    Args:
        task: The task message to modify in place.
        options: Options to serialize with ``json.dumps``; ``None`` clears
            the field instead.
        field: Name of the message field that holds ``json_options``.
    """
    if options is not None:
        getattr(task, field).json_options = json.dumps(options)
    else:
        task.ClearField(field)
|
||||
|
||||
|
||||
def decode_exception(data) -> Exception:
    """Base64-decode *data* and unpickle the result with cloudpickle."""
    raw = base64.standard_b64decode(data)
    return cloudpickle.loads(raw)
|
||||
def return_refs(ids: List[bytes]
                ) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
    """Wrap raw ids in ``ClientObjectRef`` objects.

    Returns:
        ``None`` for an empty list, a single ``ClientObjectRef`` for one id,
        and a list of ``ClientObjectRef`` otherwise.
    """
    count = len(ids)
    if count == 0:
        return None
    if count == 1:
        return ClientObjectRef(ids[0])
    return [ClientObjectRef(ref_id) for ref_id in ids]
|
||||
|
||||
|
||||
class DataEncodingSentinel:
    """Base marker class whose ``repr`` is simply its class name."""

    def __repr__(self) -> str:
        return type(self).__name__


class SelfReferenceSentinel(DataEncodingSentinel):
    """Marker subclass of ``DataEncodingSentinel``; adds no behavior."""
|
||||
|
||||
|
||||
def remote_decorator(options: Optional[Dict[str, Any]]):
    """Return a decorator that wraps a function or class in a client stub.

    Args:
        options: Passed through unchanged as ``options=`` to the stub that
            the decorator creates.

    Returns:
        A decorator producing a ``ClientRemoteFunc`` for functions (including
        cython functions) or a ``ClientActorClass`` for classes.
    """

    def decorator(function_or_class) -> ClientStub:
        wraps_function = (inspect.isfunction(function_or_class)
                          or is_cython(function_or_class))
        if wraps_function:
            return ClientRemoteFunc(function_or_class, options=options)
        if inspect.isclass(function_or_class):
            return ClientActorClass(function_or_class, options=options)
        raise TypeError("The @ray.remote decorator must be applied to "
                        "either a function or to a class.")

    return decorator
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
"""This file implements a threaded stream controller to abstract a data stream
|
||||
back to the ray clientserver.
|
||||
"""
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import grpc
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The maximum field value for request_id -- which is also the maximum
|
||||
# number of simultaneous in-flight requests.
|
||||
INT32_MAX = (2**31) - 1
|
||||
|
||||
|
||||
class DataClient:
    """Threaded client for the Ray Client bidirectional data stream.

    Requests are placed on a queue that a daemon thread streams to the
    server; responses carrying a non-zero ``req_id`` are handed back to the
    caller blocked on that id in ``_blocking_send``.
    """

    def __init__(self, channel: "grpc._channel.Channel", client_id: str):
        """Initializes a thread-safe datapath over a Ray Client gRPC channel.

        Args:
            channel: connected gRPC channel
            client_id: the generated ID representing this client
        """
        self.channel = channel
        self.request_queue = queue.Queue()
        self.data_thread = self._start_datathread()
        self.ready_data: Dict[int, Any] = {}
        self.cv = threading.Condition()
        self._req_id = 0
        self._client_id = client_id
        # Set (under self.cv) once the response stream has ended, so that
        # blocked or future senders fail fast instead of waiting forever.
        self._stream_closed = False
        self.data_thread.start()

    def _next_id(self) -> int:
        """Return the next request id, wrapping from INT32_MAX back to 1."""
        # Mint ids under the condition's lock so concurrent senders can never
        # receive the same id. threading.Condition() wraps a re-entrant lock,
        # so this is safe to call while already holding self.cv.
        with self.cv:
            self._req_id += 1
            if self._req_id > INT32_MAX:
                self._req_id = 1
            # Responses that aren't tracked (like opportunistic releases)
            # have req_id=0, so make sure we never mint such an id.
            assert self._req_id != 0
            return self._req_id

    def _start_datathread(self) -> threading.Thread:
        """Create (but do not start) the daemon thread running the stream."""
        return threading.Thread(target=self._data_main, args=(), daemon=True)

    def _data_main(self) -> None:
        """Thread body: stream queued requests and publish the responses.

        Raises:
            grpc.RpcError: re-raised for any stream error other than
                CANCELLED, after it has been logged.
        """
        stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
        # The request iterator ends when ``None`` is put on the queue.
        resp_stream = stub.Datapath(
            iter(self.request_queue.get, None),
            metadata=(("client_id", self._client_id), ))
        try:
            for response in resp_stream:
                if response.req_id == 0:
                    # This is not being waited for.
                    logger.debug(f"Got unawaited response {response}")
                    continue
                with self.cv:
                    self.ready_data[response.req_id] = response
                    self.cv.notify_all()
        except grpc.RpcError as e:
            if grpc.StatusCode.CANCELLED == e.code():
                # Gracefully shutting down
                logger.info("Cancelling data channel")
            else:
                logger.error(
                    f"Got Error from data channel -- shutting down: {e}")
                raise e
        finally:
            # No more responses will ever arrive: wake every waiter so it
            # can report the failure rather than block indefinitely.
            with self.cv:
                self._stream_closed = True
                self.cv.notify_all()

    def close(self) -> None:
        """End the request stream and wait for the data thread to finish."""
        if self.request_queue is not None:
            self.request_queue.put(None)
        if self.data_thread is not None:
            self.data_thread.join()

    def _blocking_send(self, req: ray_client_pb2.DataRequest
                       ) -> ray_client_pb2.DataResponse:
        """Send *req* and block until the response with its id arrives.

        Raises:
            ConnectionError: if the data stream has already ended, or ends
                before a response for this request is received.
        """
        with self.cv:
            if self._stream_closed:
                raise ConnectionError(
                    "Cannot send request: the data channel is closed")
            req_id = self._next_id()
            req.req_id = req_id
            # The queue is unbounded, so this put never blocks under the lock.
            self.request_queue.put(req)
            self.cv.wait_for(
                lambda: req_id in self.ready_data or self._stream_closed)
            if req_id not in self.ready_data:
                raise ConnectionError(
                    "Data channel closed before a response was received")
            return self.ready_data.pop(req_id)

    def GetObject(self, request: ray_client_pb2.GetRequest,
                  context=None) -> ray_client_pb2.GetResponse:
        """Send a get request over the stream and return its response."""
        datareq = ray_client_pb2.DataRequest(get=request, )
        resp = self._blocking_send(datareq)
        return resp.get

    def PutObject(self, request: ray_client_pb2.PutRequest,
                  context=None) -> ray_client_pb2.PutResponse:
        """Send a put request over the stream and return its response."""
        datareq = ray_client_pb2.DataRequest(put=request, )
        resp = self._blocking_send(datareq)
        return resp.put

    def ReleaseObject(self,
                      request: ray_client_pb2.ReleaseRequest,
                      context=None) -> None:
        """Queue a release request without waiting for any response."""
        # No req_id is assigned, so the response is treated as unawaited.
        datareq = ray_client_pb2.DataRequest(release=request, )
        self.request_queue.put(datareq)
|
||||
@@ -0,0 +1,7 @@
|
||||
from ray.experimental.client import ray
|
||||
|
||||
from ray.tune import tune
|
||||
|
||||
ray.connect("localhost:50051")
|
||||
|
||||
tune.run("PG", config={"env": "CartPole-v0"})
|
||||
@@ -0,0 +1,86 @@
|
||||
"""This file implements a threaded stream controller to return logs back from
|
||||
the ray clientserver.
|
||||
"""
|
||||
import sys
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
import grpc
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# TODO(barakmich): Running a logger in a logger causes loopback.
|
||||
# The client logger needs its own root -- possibly this one.
|
||||
# For the moment, let's just not propagate beyond this point.
|
||||
logger.propagate = False
|
||||
|
||||
|
||||
class LogstreamClient:
    """Threaded client that receives log records from the Ray Client server.

    Log-setting requests are queued and streamed to the server by a daemon
    thread, which dispatches each received record to ``stdstream`` and/or
    ``log``.
    """

    def __init__(self, channel: "grpc._channel.Channel"):
        """Initializes a thread-safe log stream over a Ray Client gRPC channel.

        Args:
            channel: connected gRPC channel
        """
        self.channel = channel
        self.request_queue = queue.Queue()
        self.log_thread = self._start_logthread()
        self.log_thread.start()

    def _start_logthread(self) -> threading.Thread:
        """Create (but do not start) the daemon thread reading the stream."""
        worker = threading.Thread(
            target=self._log_main, args=(), daemon=True)
        return worker

    def _log_main(self) -> None:
        """Thread body: open the log stream and dispatch incoming records."""
        stub = ray_client_pb2_grpc.RayletLogStreamerStub(self.channel)
        # The request iterator ends when ``None`` is put on the queue.
        log_stream = stub.Logstream(iter(self.request_queue.get, None))
        try:
            for record in log_stream:
                level = record.level
                msg = record.msg
                # Negative levels are additionally routed to stdstream().
                if level < 0:
                    self.stdstream(level=level, msg=msg)
                self.log(level=level, msg=msg)
        except grpc.RpcError as e:
            if grpc.StatusCode.CANCELLED == e.code():
                # Just shutting down normally
                return
            logger.error(
                f"Got Error from logger channel -- shutting down: {e}")
            raise e

    def log(self, level: int, msg: str):
        """Log the message from the log stream.
        By default, calls logger.log but this can be overridden.

        Args:
            level: The loglevel of the received log message
            msg: The content of the message
        """
        logger.log(level=level, msg=msg)

    def stdstream(self, level: int, msg: str):
        """Log the stdout/stderr entry from the log stream.
        By default, calls print but this can be overridden.

        Args:
            level: The loglevel of the received log message
            msg: The content of the message
        """
        # Level -2 goes to stderr; every other level goes to stdout.
        target = sys.stdout
        if level == -2:
            target = sys.stderr
        print(msg, file=target)

    def set_logstream_level(self, level: int):
        """Set the local logger level and request logs at *level*."""
        logger.setLevel(level)
        settings = ray_client_pb2.LogSettingsRequest()
        settings.enabled = True
        settings.loglevel = level
        self.request_queue.put(settings)

    def close(self) -> None:
        """End the request stream and wait for the log thread to finish."""
        self.request_queue.put(None)
        worker = self.log_thread
        if worker is not None:
            worker.join()

    def disable_logs(self) -> None:
        """Queue a request that sets ``enabled`` to False on the server."""
        settings = ray_client_pb2.LogSettingsRequest()
        settings.enabled = False
        self.request_queue.put(settings)
|
||||
@@ -0,0 +1,54 @@
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Optional
|
||||
|
||||
options = {
|
||||
"num_returns": (int, lambda x: x >= 0,
|
||||
"The keyword 'num_returns' only accepts 0 "
|
||||
"or a positive integer"),
|
||||
"num_cpus": (),
|
||||
"num_gpus": (),
|
||||
"resources": (),
|
||||
"accelerator_type": (),
|
||||
"max_calls": (int, lambda x: x >= 0,
|
||||
"The keyword 'max_calls' only accepts 0 "
|
||||
"or a positive integer"),
|
||||
"max_restarts": (int, lambda x: x >= -1,
|
||||
"The keyword 'max_restarts' only accepts -1, 0 "
|
||||
"or a positive integer"),
|
||||
"max_task_retries": (int, lambda x: x >= -1,
|
||||
"The keyword 'max_task_retries' only accepts -1, 0 "
|
||||
"or a positive integer"),
|
||||
"max_retries": (int, lambda x: x >= -1,
|
||||
"The keyword 'max_retries' only accepts 0, -1 "
|
||||
"or a positive integer"),
|
||||
"max_concurrency": (),
|
||||
"name": (),
|
||||
"lifetime": (),
|
||||
"memory": (),
|
||||
"object_store_memory": (),
|
||||
"placement_group": (),
|
||||
"placement_group_bundle_index": (),
|
||||
"placement_group_capture_child_tasks": (),
|
||||
"override_environment_variables": (),
|
||||
}
|
||||
|
||||
|
||||
def validate_options(
|
||||
kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
|
||||
if kwargs_dict is None:
|
||||
return None
|
||||
if len(kwargs_dict) == 0:
|
||||
return None
|
||||
out = {}
|
||||
for k, v in kwargs_dict.items():
|
||||
if k not in options.keys():
|
||||
raise TypeError(f"Invalid option passed to remote(): {k}")
|
||||
validator = options[k]
|
||||
if len(validator) != 0:
|
||||
if not isinstance(v, validator[0]):
|
||||
raise ValueError(validator[2])
|
||||
if not validator[1](v):
|
||||
raise ValueError(validator[2])
|
||||
out[k] = v
|
||||
return out
|
||||
@@ -0,0 +1,17 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
import ray.experimental.client.server.server as ray_client_server
|
||||
from ray.experimental.client import ray
|
||||
|
||||
|
||||
@contextmanager
def ray_start_client_server():
    """Run a client server on localhost:50051 and yield a connected ``ray``.

    Sets ``ray._inside_client_test`` for the duration of the block. On exit
    the flag is reset, the client is disconnected (if it connected) and the
    server is stopped (if it started), even when startup itself failed.
    """
    ray._inside_client_test = True
    server = None
    connected = False
    try:
        # Startup happens inside the try so that a failure in serve() or
        # connect() still resets the flag and stops a half-started server.
        server = ray_client_server.serve("localhost:50051")
        ray.connect("localhost:50051")
        connected = True
        yield ray
    finally:
        ray._inside_client_test = False
        try:
            if connected:
                ray.disconnect()
        finally:
            # Stop the server even if disconnecting raised.
            if server is not None:
                server.stop(0)
|
||||
@@ -1,101 +0,0 @@
|
||||
# Along with `api.py` this is the stub that interfaces with
|
||||
# the real (C-binding, raylet) ray core.
|
||||
#
|
||||
# Ideally, the first import line is the only time we actually
|
||||
# import ray in this library (excluding the main function for the server)
|
||||
#
|
||||
# While the stub is trivial, it allows us to check that the calls we're
|
||||
# making into the core-ray module are contained and well-defined.
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
import ray
|
||||
|
||||
from ray.experimental.client.api import APIImpl
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientStub
|
||||
|
||||
|
||||
class CoreRayAPI(APIImpl):
    """
    Implements the equivalent client-side Ray API by simply passing along to
    the Core Ray API. Primarily used inside of Ray Workers as a trampoline back
    to core ray when passed client stubs.
    """

    def get(self, vals, *, timeout: Optional[float] = None) -> Any:
        """Call ``ray.get`` after unpacking any ``ClientObjectRef`` values.

        Args:
            vals: A single ref or a list of refs. Every ``ClientObjectRef``
                is replaced by ``_unpack_ref()``; other values pass through
                unchanged. An empty list is forwarded as-is.
            timeout: Forwarded to ``ray.get``.
        """
        if isinstance(vals, list):
            # Check each element rather than only the first: this avoids an
            # IndexError on an empty list and copes with mixed lists.
            vals = [
                val._unpack_ref() if isinstance(val, ClientObjectRef) else val
                for val in vals
            ]
        elif isinstance(vals, ClientObjectRef):
            vals = vals._unpack_ref()
        return ray.get(vals, timeout=timeout)

    def put(self, vals: Any, *args,
            **kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]:
        """Forward directly to ``ray.put``."""
        return ray.put(vals, *args, **kwargs)

    def wait(self, *args, **kwargs):
        """Forward directly to ``ray.wait``."""
        return ray.wait(*args, **kwargs)

    def remote(self, *args, **kwargs):
        """Forward directly to ``ray.remote``."""
        return ray.remote(*args, **kwargs)

    def call_remote(self, instance: ClientStub, *args, **kwargs):
        """Invoke ``.remote()`` on the stub's underlying ray implementation."""
        return instance._get_ray_remote_impl().remote(*args, **kwargs)

    def close(self) -> None:
        """Do nothing and return ``None``."""
        return None

    def kill(self, actor, *, no_restart=True):
        """Forward directly to ``ray.kill``."""
        return ray.kill(actor, no_restart=no_restart)

    def cancel(self, obj, *, force=False, recursive=True):
        """Forward directly to ``ray.cancel``."""
        return ray.cancel(obj, force=force, recursive=recursive)

    def is_initialized(self) -> bool:
        """Forward directly to ``ray.is_initialized``."""
        return ray.is_initialized()

    # Allow for generic fallback to ray.* in remote methods. This allows calls
    # like ray.nodes() to be run in remote functions even though the client
    # doesn't currently support them.
    def __getattr__(self, key: str):
        return getattr(ray, key)
|
||||
|
||||
|
||||
class RayServerAPI(CoreRayAPI):
    """
    Ray Client server-side API shim. By default, simply calls the default Core
    Ray API calls, but also accepts scheduling calls from functions running
    inside of other remote functions that need to create more work.
    """

    def __init__(self, server_instance):
        self.server = server_instance

    def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef:
        """Put one value, or each element of a list, through the server.

        Returns:
            A single ``ClientObjectRef`` for a non-list input, or a list of
            them (one per element) when *vals* is a list. Extra positional
            and keyword arguments are accepted but not used.
        """
        if isinstance(vals, list):
            return [self._put(item) for item in vals]
        return self._put(vals)

    def _put(self, val: Any):
        """Store *val* via the server and wrap the returned id."""
        stored = self.server._put_and_retain_obj(val)
        return ClientObjectRef(stored.id)

    def call_remote(self, instance: ClientStub, *args, **kwargs):
        """Schedule the stub's task on the server.

        The positional arguments are passed as ``prepared_args``; keyword
        arguments are not forwarded.
        """
        ticket = self.server.Schedule(
            instance._prepare_client_task(), prepared_args=args)
        return ClientObjectRef(ticket.return_id)
|
||||
@@ -0,0 +1,54 @@
|
||||
import logging
|
||||
import grpc
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.experimental.client.server.server import RayletServicer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
    """gRPC servicer for the per-client bidirectional data stream."""

    def __init__(self, basic_service: "RayletServicer"):
        """Keep a reference to the servicer that performs the actual work."""
        self.basic_service = basic_service

    def Datapath(self, request_iterator, context):
        """Serve get/put/release requests from one client's data stream.

        Yields one ``DataResponse`` per request, echoing the request's
        ``req_id``. When the stream ends for any reason, every reference
        held for the client is released via ``release_all``.

        Raises:
            Exception: if a request carries a type other than get, put or
                release.
        """
        metadata = dict(context.invocation_metadata())
        # Use .get() so that a client sending no client_id metadata at all
        # takes the same logged-error path as one sending an empty id,
        # instead of failing with a KeyError.
        client_id = metadata.get("client_id", "")
        if client_id == "":
            logger.error("Client connecting with no client_id")
            return
        logger.info(f"New data connection from client {client_id}")
        try:
            for req in request_iterator:
                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "get":
                    get_resp = self.basic_service._get_object(
                        req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    put_resp = self.basic_service._put_object(
                        req.put, client_id)
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = [
                        self.basic_service.release(client_id, rel_id)
                        for rel_id in req.release.ids
                    ]
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released))
                else:
                    raise Exception(f"Unreachable code: Request type "
                                    f"{req_type} not handled in Datapath")
                # Echo the id so the client can match response to request.
                resp.req_id = req.req_id
                yield resp
        except grpc.RpcError as e:
            logger.debug(f"Closing data channel: {e}")
        finally:
            logger.info(f"Lost data connection from client {client_id}")
            self.basic_service.release_all(client_id)
|
||||
@@ -0,0 +1,101 @@
|
||||
"""This file responds to log stream requests and forwards logs
|
||||
with its handler.
|
||||
"""
|
||||
import io
|
||||
import threading
|
||||
import queue
|
||||
import logging
|
||||
import grpc
|
||||
import uuid
|
||||
|
||||
from ray.worker import print_worker_logs
|
||||
from ray.ray_logging import global_worker_stdstream_dispatcher
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LogstreamHandler(logging.Handler):
    """Logging handler that puts each record on a queue as ``LogData``."""

    def __init__(self, queue, level):
        """Store the target queue and set this handler's level.

        Args:
            queue: Object with a ``put`` method that receives the messages.
            level: Value assigned directly to the handler's ``level``.
        """
        super().__init__()
        self.level = level
        self.queue = queue

    def emit(self, record: logging.LogRecord):
        """Copy the record's message, level and name into a ``LogData``."""
        entry = ray_client_pb2.LogData()
        entry.msg = record.getMessage()
        entry.level = record.levelno
        entry.name = record.name
        self.queue.put(entry)
|
||||
|
||||
|
||||
class StdStreamHandler:
    """Forwards worker stdout/stderr data onto a queue as ``LogData``."""

    def __init__(self, queue):
        """Store the target queue and generate a unique handler id."""
        self.id = str(uuid.uuid4())
        self.queue = queue

    def handle(self, data):
        """Render *data* with ``print_worker_logs`` and queue the text.

        The level is -2 and the name "stderr" when ``data["is_err"]`` is
        truthy; otherwise the level is -1 and the name "stdout".
        """
        entry = ray_client_pb2.LogData()
        is_err = data["is_err"]
        entry.level = -2 if is_err else -1
        entry.name = "stderr" if is_err else "stdout"
        with io.StringIO() as buffer:
            print_worker_logs(data, buffer)
            # Must be read before the buffer is closed by the with block.
            entry.msg = buffer.getvalue()
        self.queue.put(entry)

    def register_global(self):
        """Register ``handle`` with the global stdstream dispatcher."""
        global_worker_stdstream_dispatcher.add_handler(self.id, self.handle)

    def unregister_global(self):
        """Remove this handler's id from the global stdstream dispatcher."""
        global_worker_stdstream_dispatcher.remove_handler(self.id)
|
||||
|
||||
|
||||
def log_status_change_thread(log_queue, request_iterator):
    """Apply each log-settings request to the "ray" logger until the stream ends.

    For every request the previously installed handler (if any) is removed
    and the logger's level restored; if the request enables logging, a new
    ``LogstreamHandler`` feeding *log_queue* is installed at the requested
    level. On exit the last handler is removed and ``None`` is put on
    *log_queue*.
    """
    std_handler = StdStreamHandler(log_queue)
    current_handler = None
    root_logger = logging.getLogger("ray")
    default_level = root_logger.getEffectiveLevel()

    def detach(handler):
        # Undo everything that installing a handler changed.
        root_logger.setLevel(default_level)
        root_logger.removeHandler(handler)
        std_handler.unregister_global()

    try:
        for req in request_iterator:
            if current_handler is not None:
                detach(current_handler)
            if req.enabled:
                current_handler = LogstreamHandler(log_queue, req.loglevel)
                std_handler.register_global()
                root_logger.addHandler(current_handler)
                root_logger.setLevel(req.loglevel)
            else:
                current_handler = None
    except grpc.RpcError as e:
        logger.debug(f"closing log thread "
                     f"grpc error reading request_iterator: {e}")
    finally:
        if current_handler is not None:
            detach(current_handler)
        # Always signal the end of the queue, whatever ended the loop.
        log_queue.put(None)
|
||||
|
||||
|
||||
class LogstreamServicer(ray_client_pb2_grpc.RayletLogStreamerServicer):
    """gRPC servicer that streams queued log records back to a client."""

    def Logstream(self, request_iterator, context):
        """Yield log records until the settings thread signals the end.

        A daemon thread running ``log_status_change_thread`` consumes
        *request_iterator* and fills the queue; it puts ``None`` on the
        queue when it finishes, which ends this generator.
        """
        logger.info("New logs connection")
        log_queue = queue.Queue()
        thread = threading.Thread(
            target=log_status_change_thread,
            args=(log_queue, request_iterator),
            daemon=True)
        thread.start()
        try:
            # iter(get, None) stops as soon as None is dequeued, so the loop
            # body never sees the sentinel and needs no explicit check.
            for record in iter(log_queue.get, None):
                yield record
        except grpc.RpcError as e:
            logger.debug(f"Closing log channel: {e}")
        finally:
            thread.join()
|
||||
@@ -1,6 +1,14 @@
|
||||
import logging
|
||||
from concurrent import futures
|
||||
import grpc
|
||||
import base64
|
||||
from collections import defaultdict
|
||||
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import Set
|
||||
from typing import Optional
|
||||
|
||||
from ray import cloudpickle
|
||||
import ray
|
||||
import ray.state
|
||||
@@ -9,29 +17,34 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
import time
|
||||
import inspect
|
||||
import json
|
||||
from ray.experimental.client import stash_api_for_tests, _set_server_api
|
||||
from ray.experimental.client.common import convert_from_arg
|
||||
from ray.experimental.client.common import encode_exception
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.server.core_ray_api import RayServerAPI
|
||||
from ray.experimental.client.server.server_pickler import convert_from_arg
|
||||
from ray.experimental.client.server.server_pickler import dumps_from_server
|
||||
from ray.experimental.client.server.server_pickler import loads_from_client
|
||||
from ray.experimental.client.server.dataservicer import DataServicer
|
||||
from ray.experimental.client.server.logservicer import LogstreamServicer
|
||||
from ray.experimental.client.server.server_stubs import current_remote
|
||||
from ray._private.client_mode_hook import disable_client_hook
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
def __init__(self, test_mode=False):
|
||||
self.object_refs = {}
|
||||
def __init__(self):
|
||||
self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict(
|
||||
dict)
|
||||
self.function_refs = {}
|
||||
self.actor_refs = {}
|
||||
self.actor_refs: Dict[bytes, ray.ActorHandle] = {}
|
||||
self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set)
|
||||
self.registered_actor_classes = {}
|
||||
self._test_mode = test_mode
|
||||
self._current_function_stub = None
|
||||
|
||||
def ClusterInfo(self, request,
|
||||
context=None) -> ray_client_pb2.ClusterInfoResponse:
|
||||
resp = ray_client_pb2.ClusterInfoResponse()
|
||||
resp.type = request.type
|
||||
if request.type == ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES:
|
||||
resources = ray.cluster_resources()
|
||||
with disable_client_hook():
|
||||
resources = ray.cluster_resources()
|
||||
# Normalize resources into floats
|
||||
# (the function may return values that are ints)
|
||||
float_resources = {k: float(v) for k, v in resources.items()}
|
||||
@@ -40,7 +53,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
table=float_resources))
|
||||
elif request.type == \
|
||||
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES:
|
||||
resources = ray.available_resources()
|
||||
with disable_client_hook():
|
||||
resources = ray.available_resources()
|
||||
# Normalize resources into floats
|
||||
# (the function may return values that are ints)
|
||||
float_resources = {k: float(v) for k, v in resources.items()}
|
||||
@@ -48,7 +62,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
ray_client_pb2.ClusterInfoResponse.ResourceTable(
|
||||
table=float_resources))
|
||||
else:
|
||||
resp.json = self._return_debug_cluster_info(request, context)
|
||||
with disable_client_hook():
|
||||
resp.json = self._return_debug_cluster_info(request, context)
|
||||
return resp
|
||||
|
||||
def _return_debug_cluster_info(self, request, context=None) -> str:
|
||||
@@ -61,20 +76,61 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
raise TypeError("Unsupported cluster info type")
|
||||
return json.dumps(data)
|
||||
|
||||
def Terminate(self, request, context=None):
|
||||
if request.WhichOneof("terminate_type") == "task_object":
|
||||
def release(self, client_id: str, id: bytes) -> bool:
|
||||
if client_id in self.object_refs:
|
||||
if id in self.object_refs[client_id]:
|
||||
logger.debug(f"Releasing object {id.hex()} for {client_id}")
|
||||
del self.object_refs[client_id][id]
|
||||
return True
|
||||
|
||||
if client_id in self.actor_owners:
|
||||
if id in self.actor_owners[client_id]:
|
||||
logger.debug(f"Releasing actor {id.hex()} for {client_id}")
|
||||
del self.actor_refs[id]
|
||||
self.actor_owners[client_id].remove(id)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def release_all(self, client_id):
|
||||
self._release_objects(client_id)
|
||||
self._release_actors(client_id)
|
||||
|
||||
def _release_objects(self, client_id):
|
||||
if client_id not in self.object_refs:
|
||||
logger.debug(f"Releasing client with no references: {client_id}")
|
||||
return
|
||||
count = len(self.object_refs[client_id])
|
||||
del self.object_refs[client_id]
|
||||
logger.debug(f"Released all {count} objects for client {client_id}")
|
||||
|
||||
def _release_actors(self, client_id):
|
||||
if client_id not in self.actor_owners:
|
||||
logger.debug(f"Releasing client with no actors: {client_id}")
|
||||
count = 0
|
||||
for id_bytes in self.actor_owners[client_id]:
|
||||
count += 1
|
||||
del self.actor_refs[id_bytes]
|
||||
del self.actor_owners[client_id]
|
||||
logger.debug(f"Released all {count} actors for client: {client_id}")
|
||||
|
||||
def Terminate(self, req, context=None):
|
||||
if req.WhichOneof("terminate_type") == "task_object":
|
||||
try:
|
||||
object_ref = cloudpickle.loads(request.task_object.handle)
|
||||
ray.cancel(
|
||||
object_ref,
|
||||
force=request.task_object.force,
|
||||
recursive=request.task_object.recursive)
|
||||
object_ref = \
|
||||
self.object_refs[req.client_id][req.task_object.id]
|
||||
with disable_client_hook():
|
||||
ray.cancel(
|
||||
object_ref,
|
||||
force=req.task_object.force,
|
||||
recursive=req.task_object.recursive)
|
||||
except Exception as e:
|
||||
return_exception_in_context(e, context)
|
||||
elif request.WhichOneof("terminate_type") == "actor":
|
||||
elif req.WhichOneof("terminate_type") == "actor":
|
||||
try:
|
||||
actor_ref = cloudpickle.loads(request.actor.handle)
|
||||
ray.kill(actor_ref, no_restart=request.actor.no_restart)
|
||||
actor_ref = self.actor_refs[req.actor.id]
|
||||
with disable_client_hook():
|
||||
ray.kill(actor_ref, no_restart=req.actor.no_restart)
|
||||
except Exception as e:
|
||||
return_exception_in_context(e, context)
|
||||
else:
|
||||
@@ -84,166 +140,221 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
|
||||
return ray_client_pb2.TerminateResponse(ok=True)
|
||||
|
||||
def GetObject(self, request, context=None):
|
||||
request_ref = cloudpickle.loads(request.handle)
|
||||
if request_ref.binary() not in self.object_refs:
|
||||
return self._get_object(request, "", context)
|
||||
|
||||
def _get_object(self, request, client_id: str, context=None):
|
||||
if request.id not in self.object_refs[client_id]:
|
||||
return ray_client_pb2.GetResponse(valid=False)
|
||||
objectref = self.object_refs[request_ref.binary()]
|
||||
logger.info("get: %s" % objectref)
|
||||
objectref = self.object_refs[client_id][request.id]
|
||||
logger.debug("get: %s" % objectref)
|
||||
try:
|
||||
item = ray.get(objectref, timeout=request.timeout)
|
||||
with disable_client_hook():
|
||||
item = ray.get(objectref, timeout=request.timeout)
|
||||
except Exception as e:
|
||||
return_exception_in_context(e, context)
|
||||
item_ser = cloudpickle.dumps(item)
|
||||
return ray_client_pb2.GetResponse(
|
||||
valid=False, error=cloudpickle.dumps(e))
|
||||
item_ser = dumps_from_server(item, client_id, self)
|
||||
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
|
||||
|
||||
def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
|
||||
obj = cloudpickle.loads(request.data)
|
||||
objectref = self._put_and_retain_obj(obj)
|
||||
pickled_ref = cloudpickle.dumps(objectref)
|
||||
return ray_client_pb2.PutResponse(
|
||||
ref=make_remote_ref(objectref.binary(), pickled_ref))
|
||||
def PutObject(self, request: ray_client_pb2.PutRequest,
|
||||
context=None) -> ray_client_pb2.PutResponse:
|
||||
"""gRPC entrypoint for unary PutObject
|
||||
"""
|
||||
return self._put_object(request, "", context)
|
||||
|
||||
def _put_and_retain_obj(self, obj) -> ray.ObjectRef:
|
||||
objectref = ray.put(obj)
|
||||
self.object_refs[objectref.binary()] = objectref
|
||||
logger.info("put: %s" % objectref)
|
||||
return objectref
|
||||
def _put_object(self,
|
||||
request: ray_client_pb2.PutRequest,
|
||||
client_id: str,
|
||||
context=None):
|
||||
"""Put an object in the cluster with ray.put() via gRPC.
|
||||
|
||||
Args:
|
||||
request: PutRequest with pickled data.
|
||||
client_id: The client who owns this data, for tracking when to
|
||||
delete this reference.
|
||||
context: gRPC context.
|
||||
"""
|
||||
obj = loads_from_client(request.data, self)
|
||||
with disable_client_hook():
|
||||
objectref = ray.put(obj)
|
||||
self.object_refs[client_id][objectref.binary()] = objectref
|
||||
logger.debug("put: %s" % objectref)
|
||||
return ray_client_pb2.PutResponse(id=objectref.binary())
|
||||
|
||||
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
|
||||
object_refs = [cloudpickle.loads(o) for o in request.object_handles]
|
||||
object_refs = []
|
||||
for id in request.object_ids:
|
||||
if id not in self.object_refs[request.client_id]:
|
||||
raise Exception(
|
||||
"Asking for a ref not associated with this client: %s" %
|
||||
str(id))
|
||||
object_refs.append(self.object_refs[request.client_id][id])
|
||||
num_returns = request.num_returns
|
||||
timeout = request.timeout
|
||||
object_refs_ids = []
|
||||
for object_ref in object_refs:
|
||||
if object_ref.binary() not in self.object_refs:
|
||||
return ray_client_pb2.WaitResponse(valid=False)
|
||||
object_refs_ids.append(self.object_refs[object_ref.binary()])
|
||||
try:
|
||||
ready_object_refs, remaining_object_refs = ray.wait(
|
||||
object_refs_ids,
|
||||
num_returns=num_returns,
|
||||
timeout=timeout if timeout != -1 else None)
|
||||
except Exception:
|
||||
with disable_client_hook():
|
||||
ready_object_refs, remaining_object_refs = ray.wait(
|
||||
object_refs,
|
||||
num_returns=num_returns,
|
||||
timeout=timeout if timeout != -1 else None,
|
||||
)
|
||||
except Exception as e:
|
||||
# TODO(ameer): improve exception messages.
|
||||
logger.error(f"Exception {e}")
|
||||
return ray_client_pb2.WaitResponse(valid=False)
|
||||
logger.info("wait: %s %s" % (str(ready_object_refs),
|
||||
str(remaining_object_refs)))
|
||||
logger.debug("wait: %s %s" % (str(ready_object_refs),
|
||||
str(remaining_object_refs)))
|
||||
ready_object_ids = [
|
||||
make_remote_ref(
|
||||
id=ready_object_ref.binary(),
|
||||
handle=cloudpickle.dumps(ready_object_ref),
|
||||
) for ready_object_ref in ready_object_refs
|
||||
ready_object_ref.binary() for ready_object_ref in ready_object_refs
|
||||
]
|
||||
remaining_object_ids = [
|
||||
make_remote_ref(
|
||||
id=remaining_object_ref.binary(),
|
||||
handle=cloudpickle.dumps(remaining_object_ref),
|
||||
) for remaining_object_ref in remaining_object_refs
|
||||
remaining_object_ref.binary()
|
||||
for remaining_object_ref in remaining_object_refs
|
||||
]
|
||||
return ray_client_pb2.WaitResponse(
|
||||
valid=True,
|
||||
ready_object_ids=ready_object_ids,
|
||||
remaining_object_ids=remaining_object_ids)
|
||||
|
||||
def Schedule(self, task, context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
logger.info("schedule: %s %s" %
|
||||
(task.name,
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
|
||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||
return self._schedule_function(task, context, prepared_args)
|
||||
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
||||
return self._schedule_actor(task, context, prepared_args)
|
||||
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
||||
return self._schedule_method(task, context, prepared_args)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented Schedule task type: %s" %
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
|
||||
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
logger.debug(
|
||||
"schedule: %s %s" % (task.name,
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(
|
||||
task.type)))
|
||||
try:
|
||||
with disable_client_hook():
|
||||
if task.type == ray_client_pb2.ClientTask.FUNCTION:
|
||||
result = self._schedule_function(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.ACTOR:
|
||||
result = self._schedule_actor(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.METHOD:
|
||||
result = self._schedule_method(task, context)
|
||||
elif task.type == ray_client_pb2.ClientTask.NAMED_ACTOR:
|
||||
result = self._schedule_named_actor(task, context)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Unimplemented Schedule task type: %s" %
|
||||
ray_client_pb2.ClientTask.RemoteExecType.Name(
|
||||
task.type))
|
||||
result.valid = True
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Caught schedule exception {e}")
|
||||
raise e
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
valid=False, error=cloudpickle.dumps(e))
|
||||
|
||||
def _schedule_method(
|
||||
self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
def _schedule_method(self, task: ray_client_pb2.ClientTask,
|
||||
context=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
actor_handle = self.actor_refs.get(task.payload_id)
|
||||
if actor_handle is None:
|
||||
raise Exception(
|
||||
"Can't run an actor the server doesn't have a handle for")
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
output = getattr(actor_handle, task.name).remote(*arglist)
|
||||
self.object_refs[output.binary()] = output
|
||||
pickled_ref = cloudpickle.dumps(output)
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ref=make_remote_ref(output.binary(), pickled_ref))
|
||||
arglist, kwargs = self._convert_args(task.args, task.kwargs)
|
||||
method = getattr(actor_handle, task.name)
|
||||
opts = decode_options(task.options)
|
||||
if opts is not None:
|
||||
method = method.options(**opts)
|
||||
output = method.remote(*arglist, **kwargs)
|
||||
ids = self.unify_and_track_outputs(output, task.client_id)
|
||||
return ray_client_pb2.ClientTaskTicket(return_ids=ids)
|
||||
|
||||
def _schedule_actor(self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
payload_ref = cloudpickle.loads(task.payload_id)
|
||||
if payload_ref.binary() not in self.registered_actor_classes:
|
||||
actor_class_ref = self.object_refs[payload_ref.binary()]
|
||||
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
                    context=None) -> ray_client_pb2.ClientTaskTicket:
    """Create an actor for a client and return its id in a ticket.

    The actor class is resolved with ``lookup_or_register_actor`` from the
    task's payload id and baseline options. Per-call options, when present,
    are applied before ``.remote`` is invoked. The new handle is stored in
    ``actor_refs`` and attributed to the requesting client in
    ``actor_owners``.
    """
    baseline_opts = decode_options(task.baseline_options)
    actor_cls = self.lookup_or_register_actor(task.payload_id,
                                              task.client_id, baseline_opts)

    positional, keyword = self._convert_args(task.args, task.kwargs)
    call_opts = decode_options(task.options)
    if call_opts is not None:
        actor_cls = actor_cls.options(**call_opts)
    # Expose the class being instantiated while the call is made.
    with current_remote(actor_cls):
        actor = actor_cls.remote(*positional, **keyword)
    actor_id = actor._actor_id.binary()
    self.actor_refs[actor_id] = actor
    self.actor_owners[task.client_id].add(actor_id)
    return ray_client_pb2.ClientTaskTicket(return_ids=[actor_id])
|
||||
|
||||
def _schedule_function(self, task: ray_client_pb2.ClientTask,
                       context=None) -> ray_client_pb2.ClientTaskTicket:
    """Invoke a client-submitted remote function and return its output ids.

    The function is resolved with ``lookup_or_register_func`` from the
    task's payload id and baseline options. Per-call options, when present,
    are applied before ``.remote`` is invoked. Outputs are registered for
    the client through ``unify_and_track_outputs``.
    """
    baseline_opts = decode_options(task.baseline_options)
    target = self.lookup_or_register_func(task.payload_id, task.client_id,
                                          baseline_opts)
    positional, keyword = self._convert_args(task.args, task.kwargs)
    call_opts = decode_options(task.options)
    if call_opts is not None:
        target = target.options(**call_opts)
    # Expose the function being called while the call is made.
    with current_remote(target):
        output = target.remote(*positional, **keyword)
    return ray_client_pb2.ClientTaskTicket(
        return_ids=self.unify_and_track_outputs(output, task.client_id))
|
||||
|
||||
def _schedule_named_actor(self,
                          task: ray_client_pb2.ClientTask,
                          context=None) -> ray_client_pb2.ClientTaskTicket:
    """Look up an existing actor by ``task.name`` and hand back its id.

    The handle returned by ``ray.get_actor`` is stored in ``actor_refs``
    and attributed to the requesting client in ``actor_owners``.
    """
    # A named-actor lookup carries no payload.
    assert len(task.payload_id) == 0
    handle = ray.get_actor(task.name)
    actor_id = handle._actor_id.binary()
    self.actor_refs[actor_id] = handle
    self.actor_owners[task.client_id].add(actor_id)
    return ray_client_pb2.ClientTaskTicket(return_ids=[actor_id])
|
||||
|
||||
def _convert_args(self, arg_list, kwarg_map):
    """Decode positional and keyword arguments received from a client.

    Every entry is passed through ``convert_from_arg`` together with this
    servicer.

    Returns:
        A ``(list, dict)`` pair holding the converted positional and
        keyword arguments.
    """
    positional = [convert_from_arg(item, self) for item in arg_list]
    keyword = {
        key: convert_from_arg(kwarg_map[key], self)
        for key in kwarg_map
    }
    return positional, keyword
|
||||
|
||||
def lookup_or_register_func(
        self, id: bytes, client_id: str,
        options: Optional[Dict]) -> ray.remote_function.RemoteFunction:
    """Return the remote function registered under ``id``, creating it once.

    On first use the function object is fetched from the client's entry in
    ``object_refs`` with ``ray.get`` and wrapped with ``ray.remote``; the
    ``options`` are passed as decorator arguments when non-empty. The
    result is cached in ``function_refs``.

    Raises:
        Exception: if the fetched object is not a plain function.
    """
    with disable_client_hook():
        registry = self.function_refs
        if id not in registry:
            candidate = ray.get(self.object_refs[client_id][id])
            if not inspect.isfunction(candidate):
                raise Exception("Attempting to register function that "
                                "isn't a function.")
            has_options = options is not None and len(options) != 0
            decorate = ray.remote(**options) if has_options else ray.remote
            registry[id] = decorate(candidate)
        return registry[id]
|
||||
|
||||
def lookup_or_register_actor(self, id: bytes, client_id: str,
|
||||
options: Optional[Dict]):
|
||||
with disable_client_hook():
|
||||
if id not in self.registered_actor_classes:
|
||||
actor_class_ref = self.object_refs[client_id][id]
|
||||
actor_class = ray.get(actor_class_ref)
|
||||
if not inspect.isclass(actor_class):
|
||||
raise Exception("Attempting to schedule actor that "
|
||||
"isn't a class.")
|
||||
reg_class = ray.remote(actor_class)
|
||||
self.registered_actor_classes[payload_ref.binary()] = reg_class
|
||||
remote_class = self.registered_actor_classes[payload_ref.binary()]
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
actor = remote_class.remote(*arglist)
|
||||
actorhandle = cloudpickle.dumps(actor)
|
||||
self.actor_refs[actorhandle] = actor
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ref=make_remote_ref(actor._actor_id.binary(), actorhandle))
|
||||
if options is None or len(options) == 0:
|
||||
reg_class = ray.remote(actor_class)
|
||||
else:
|
||||
reg_class = ray.remote(**options)(actor_class)
|
||||
self.registered_actor_classes[id] = reg_class
|
||||
|
||||
def _schedule_function(
|
||||
self,
|
||||
task: ray_client_pb2.ClientTask,
|
||||
context=None,
|
||||
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
|
||||
payload_ref = cloudpickle.loads(task.payload_id)
|
||||
if payload_ref.binary() not in self.function_refs:
|
||||
funcref = self.object_refs[payload_ref.binary()]
|
||||
func = ray.get(funcref)
|
||||
if not inspect.isfunction(func):
|
||||
raise Exception("Attempting to schedule function that "
|
||||
"isn't a function.")
|
||||
self.function_refs[payload_ref.binary()] = ray.remote(func)
|
||||
remote_func = self.function_refs[payload_ref.binary()]
|
||||
arglist = _convert_args(task.args, prepared_args)
|
||||
# Prepare call if we're in a test
|
||||
with stash_api_for_tests(self._test_mode):
|
||||
output = remote_func.remote(*arglist)
|
||||
if output.binary() in self.object_refs:
|
||||
raise Exception("already found it")
|
||||
self.object_refs[output.binary()] = output
|
||||
pickled_output = cloudpickle.dumps(output)
|
||||
return ray_client_pb2.ClientTaskTicket(
|
||||
return_ref=make_remote_ref(output.binary(), pickled_output))
|
||||
return self.registered_actor_classes[id]
|
||||
|
||||
|
||||
def _convert_args(arg_list, prepared_args=None):
|
||||
if prepared_args is not None:
|
||||
return prepared_args
|
||||
out = []
|
||||
for arg in arg_list:
|
||||
t = convert_from_arg(arg)
|
||||
if isinstance(t, ClientObjectRef):
|
||||
out.append(t._unpack_ref())
|
||||
def unify_and_track_outputs(self, output, client_id):
|
||||
if output is None:
|
||||
outputs = []
|
||||
elif isinstance(output, list):
|
||||
outputs = output
|
||||
else:
|
||||
out.append(t)
|
||||
return out
|
||||
|
||||
|
||||
def make_remote_ref(id: bytes, handle: bytes) -> ray_client_pb2.RemoteRef:
|
||||
return ray_client_pb2.RemoteRef(
|
||||
id=id,
|
||||
handle=handle,
|
||||
)
|
||||
outputs = [output]
|
||||
for out in outputs:
|
||||
if out.binary() in self.object_refs[client_id]:
|
||||
logger.warning(f"Already saw object_ref {out}")
|
||||
self.object_refs[client_id][out.binary()] = out
|
||||
return [out.binary() for out in outputs]
|
||||
|
||||
|
||||
def return_exception_in_context(err, context):
|
||||
@@ -252,17 +363,61 @@ def return_exception_in_context(err, context):
|
||||
context.set_code(grpc.StatusCode.INTERNAL)
|
||||
|
||||
|
||||
def serve(connection_str, test_mode=False):
|
||||
def encode_exception(exception) -> str:
    """Serialize ``exception`` with cloudpickle and return it as base64 text."""
    return base64.standard_b64encode(cloudpickle.dumps(exception)).decode()
|
||||
|
||||
|
||||
def decode_options(
        options: ray_client_pb2.TaskOptions) -> Optional[Dict[str, Any]]:
    """Parse the JSON-encoded options carried by ``options``.

    Returns:
        ``None`` when ``options.json_options`` is the empty string,
        otherwise the decoded dictionary.
    """
    raw = options.json_options
    if raw != "":
        decoded = json.loads(raw)
        assert isinstance(decoded, dict)
        return decoded
    return None
|
||||
|
||||
|
||||
_current_servicer: Optional[RayletServicer] = None
|
||||
|
||||
|
||||
# Used by tests to peek inside the servicer
|
||||
def _get_current_servicer():
    """Return the module-level ``_current_servicer`` (``None`` until set)."""
    # Reading a module global needs no ``global`` declaration.
    return _current_servicer
|
||||
|
||||
|
||||
def serve(connection_str):
|
||||
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||
task_servicer = RayletServicer(test_mode=test_mode)
|
||||
_set_server_api(RayServerAPI(task_servicer))
|
||||
task_servicer = RayletServicer()
|
||||
data_servicer = DataServicer(task_servicer)
|
||||
logs_servicer = LogstreamServicer()
|
||||
global _current_servicer
|
||||
_current_servicer = task_servicer
|
||||
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
|
||||
task_servicer, server)
|
||||
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
|
||||
data_servicer, server)
|
||||
ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
|
||||
logs_servicer, server)
|
||||
server.add_insecure_port(connection_str)
|
||||
server.start()
|
||||
return server
|
||||
|
||||
|
||||
def init_and_serve(connection_str, *args, **kwargs):
|
||||
with disable_client_hook():
|
||||
# Disable client mode inside the worker's environment
|
||||
info = ray.init(*args, **kwargs)
|
||||
server = serve(connection_str)
|
||||
return (server, info)
|
||||
|
||||
|
||||
def shutdown_with_server(server, _exiting_interpreter=False):
    """Stop ``server`` and then shut Ray down.

    ``server.stop(1)`` is issued first; ``ray.shutdown`` is then called with
    the client hook disabled.
    """
    server.stop(1)
    with disable_client_hook():
        ray.shutdown(_exiting_interpreter)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level="INFO")
|
||||
# TODO(barakmich): Perhaps wrap ray init
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Implements the server side of the client/server pickling protocol.
|
||||
|
||||
These picklers are aware of the server internals and can find the
|
||||
references held for the client within the server.
|
||||
|
||||
More discussion about the client/server pickling protocol can be found in:
|
||||
|
||||
ray/experimental/client/client_pickler.py
|
||||
|
||||
ServerPickler dumps ray objects from the server into the appropriate stubs.
|
||||
ClientUnpickler loads stubs from the client and finds their associated handle
|
||||
in the server instance.
|
||||
"""
|
||||
import cloudpickle
|
||||
import io
|
||||
import sys
|
||||
import ray
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ray._private.client_mode_hook import disable_client_hook
|
||||
from ray.experimental.client.client_pickler import PickleStub
|
||||
from ray.experimental.client.server.server_stubs import (
|
||||
ServerSelfReferenceSentinel)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.experimental.client.server.server import RayletServicer
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
try:
|
||||
import pickle5 as pickle # noqa: F401
|
||||
except ImportError:
|
||||
import pickle # noqa: F401
|
||||
else:
|
||||
import pickle # noqa: F401
|
||||
|
||||
|
||||
class ServerPickler(cloudpickle.CloudPickler):
    """Pickler for data sent from the server back to one client.

    Ray object refs and actor handles found while pickling are replaced by
    ``PickleStub`` records. The real objects are kept in the servicer's
    tables so the client can refer to them later.
    """

    def __init__(self, client_id: str, server: "RayletServicer", *args,
                 **kwargs):
        """Remember the target client and the servicer holding its refs.

        Remaining arguments are forwarded to ``cloudpickle.CloudPickler``.
        """
        super().__init__(*args, **kwargs)
        self.client_id = client_id
        self.server = server

    def persistent_id(self, obj):
        """Return a ``PickleStub`` for refs and actor handles, else ``None``.

        Returning ``None`` lets the object be pickled normally.
        """
        if isinstance(obj, ray.ObjectRef):
            obj_id = obj.binary()
            if obj_id not in self.server.object_refs[self.client_id]:
                # We're passing back a reference, probably inside a reference.
                # Let's hold onto it.
                self.server.object_refs[self.client_id][obj_id] = obj
            return PickleStub(
                type="Object",
                client_id=self.client_id,
                ref_id=obj_id,
                name=None,
                baseline_options=None,
            )
        elif isinstance(obj, ray.actor.ActorHandle):
            actor_id = obj._actor_id.binary()
            # The handle and ownership tables belong to the servicer; this
            # pickler only stores ``client_id`` and ``server`` itself.
            if actor_id not in self.server.actor_refs:
                # We're passing back a handle, probably inside a reference.
                self.server.actor_refs[actor_id] = obj
            if actor_id not in self.server.actor_owners[self.client_id]:
                self.server.actor_owners[self.client_id].add(actor_id)
            return PickleStub(
                type="Actor",
                client_id=self.client_id,
                ref_id=actor_id,
                name=None,
                baseline_options=None,
            )
        return None
|
||||
|
||||
|
||||
class ClientUnpickler(pickle.Unpickler):
    """Unpickler for data received from a client.

    ``PickleStub`` persistent ids embedded in the stream are resolved
    against the servicer passed at construction time.
    """

    def __init__(self, server, *args, **kwargs):
        """Store ``server``; other arguments go to ``pickle.Unpickler``."""
        super().__init__(*args, **kwargs)
        self.server = server

    def persistent_load(self, pid):
        """Map a ``PickleStub`` to the server-side object it stands for.

        Raises:
            NotImplementedError: if the stub type is not one handled here.
        """
        assert isinstance(pid, PickleStub)
        kind = pid.type
        if kind == "Ray":
            return ray
        if kind == "Object":
            return self.server.object_refs[pid.client_id][pid.ref_id]
        if kind == "Actor":
            return self.server.actor_refs[pid.ref_id]
        if kind in ("RemoteFuncSelfReference", "RemoteActorSelfReference"):
            return ServerSelfReferenceSentinel()
        if kind == "RemoteFunc":
            return self.server.lookup_or_register_func(
                pid.ref_id, pid.client_id, pid.baseline_options)
        if kind == "RemoteActor":
            return self.server.lookup_or_register_actor(
                pid.ref_id, pid.client_id, pid.baseline_options)
        if kind == "RemoteMethod":
            owner = self.server.actor_refs[pid.ref_id]
            return getattr(owner, pid.name)
        raise NotImplementedError("Uncovered client data type")
|
||||
|
||||
|
||||
def dumps_from_server(obj: Any,
                      client_id: str,
                      server_instance: "RayletServicer",
                      protocol=None) -> bytes:
    """Pickle ``obj`` for ``client_id`` with a ``ServerPickler``.

    Returns:
        The pickled bytes.
    """
    buffer = io.BytesIO()
    try:
        pickler = ServerPickler(
            client_id, server_instance, buffer, protocol=protocol)
        pickler.dump(obj)
        return buffer.getvalue()
    finally:
        buffer.close()
|
||||
|
||||
|
||||
def loads_from_client(data: bytes,
                      server_instance: "RayletServicer",
                      *,
                      fix_imports=True,
                      encoding="ASCII",
                      errors="strict") -> Any:
    """Unpickle ``data`` sent by a client using a ``ClientUnpickler``.

    The load runs with the client hook disabled.

    Args:
        data: Pickled bytes.
        server_instance: Servicer used to resolve stubs in the stream.
        fix_imports: Forwarded to the unpickler.
        encoding: Forwarded to the unpickler.
        errors: Forwarded to the unpickler.

    Raises:
        TypeError: if ``data`` is a ``str`` instead of bytes.
    """
    # NOTE(security): this unpickles bytes supplied by the client, and
    # unpickling can execute arbitrary code. Only expose this to trusted
    # clients.
    with disable_client_hook():
        if isinstance(data, str):
            raise TypeError("Can't load pickle from unicode string")
        file = io.BytesIO(data)
        # ``errors`` was previously accepted but silently dropped.
        return ClientUnpickler(
            server_instance,
            file,
            fix_imports=fix_imports,
            encoding=encoding,
            errors=errors).load()
|
||||
|
||||
|
||||
def convert_from_arg(pb: "ray_client_pb2.Arg",
                     server: "RayletServicer") -> Any:
    """Decode the pickled payload in ``pb.data`` via ``loads_from_client``."""
    payload = pb.data
    return loads_from_client(payload, server)
|
||||
@@ -0,0 +1,29 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
_current_remote_obj = None
|
||||
|
||||
|
||||
@contextmanager
def current_remote(r):
    """Temporarily make ``r`` the module-wide current remote object.

    The value that was current on entry is put back when the ``with`` block
    exits, whether normally or through an exception, so uses can be nested.
    """
    global _current_remote_obj
    previous, _current_remote_obj = _current_remote_obj, r
    try:
        yield
    finally:
        _current_remote_obj = previous
|
||||
|
||||
|
||||
class ServerSelfReferenceSentinel:
    """Placeholder that pickles as the currently active remote object.

    When a current remote object is set, reducing the sentinel yields that
    object (through ``identity``). Otherwise it reduces to a fresh sentinel.
    """

    def __init__(self):
        pass

    def __reduce__(self):
        active = _current_remote_obj
        if active is not None:
            return (identity, (active, ))
        return (ServerSelfReferenceSentinel, ())
|
||||
|
||||
|
||||
def identity(x):
    """Return ``x`` unchanged."""
    return x
|
||||
@@ -2,27 +2,30 @@
|
||||
It implements the Ray API functions that are forwarded through grpc calls
|
||||
to the server.
|
||||
"""
|
||||
import inspect
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import Optional
|
||||
|
||||
import ray.cloudpickle as cloudpickle
|
||||
from ray.util.inspect import is_cython
|
||||
import grpc
|
||||
|
||||
from ray.exceptions import TaskCancelledError
|
||||
import ray.cloudpickle as cloudpickle
|
||||
import ray.core.generated.ray_client_pb2 as ray_client_pb2
|
||||
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
|
||||
from ray.experimental.client.common import convert_to_arg
|
||||
from ray.experimental.client.common import decode_exception
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.common import ClientActorClass
|
||||
from ray.experimental.client.client_pickler import convert_to_arg
|
||||
from ray.experimental.client.client_pickler import dumps_from_client
|
||||
from ray.experimental.client.client_pickler import loads_from_server
|
||||
from ray.experimental.client.common import ClientActorHandle
|
||||
from ray.experimental.client.common import ClientRemoteFunc
|
||||
from ray.experimental.client.common import ClientActorRef
|
||||
from ray.experimental.client.common import ClientObjectRef
|
||||
from ray.experimental.client.dataclient import DataClient
|
||||
from ray.experimental.client.logsclient import LogstreamClient
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -31,34 +34,37 @@ class Worker:
|
||||
def __init__(self,
|
||||
conn_str: str = "",
|
||||
secure: bool = False,
|
||||
metadata: List[Tuple[str, str]] = None,
|
||||
stub=None):
|
||||
metadata: List[Tuple[str, str]] = None):
|
||||
"""Initializes the worker side grpc client.
|
||||
|
||||
Args:
|
||||
stub: custom grpc stub.
|
||||
secure: whether to use SSL secure channel or not.
|
||||
metadata: additional metadata passed in the grpc request headers.
|
||||
"""
|
||||
self.metadata = metadata
|
||||
self.channel = None
|
||||
if stub is None:
|
||||
if secure:
|
||||
credentials = grpc.ssl_channel_credentials()
|
||||
self.channel = grpc.secure_channel(conn_str, credentials)
|
||||
else:
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
self._client_id = make_client_id()
|
||||
if secure:
|
||||
credentials = grpc.ssl_channel_credentials()
|
||||
self.channel = grpc.secure_channel(conn_str, credentials)
|
||||
else:
|
||||
self.server = stub
|
||||
self.channel = grpc.insecure_channel(conn_str)
|
||||
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
|
||||
|
||||
self.data_client = DataClient(self.channel, self._client_id)
|
||||
self.reference_count: Dict[bytes, int] = defaultdict(int)
|
||||
|
||||
self.log_client = LogstreamClient(self.channel)
|
||||
self.log_client.set_logstream_level(logging.INFO)
|
||||
self.closed = False
|
||||
|
||||
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
|
||||
to_get = []
|
||||
single = False
|
||||
if isinstance(vals, list):
|
||||
to_get = [x.handle for x in vals]
|
||||
to_get = vals
|
||||
elif isinstance(vals, ClientObjectRef):
|
||||
to_get = [vals.handle]
|
||||
to_get = [vals]
|
||||
single = True
|
||||
else:
|
||||
raise Exception("Can't get something that's not a "
|
||||
@@ -70,15 +76,17 @@ class Worker:
|
||||
out = out[0]
|
||||
return out
|
||||
|
||||
def _get(self, handle: bytes, timeout: float):
|
||||
req = ray_client_pb2.GetRequest(handle=handle, timeout=timeout)
|
||||
def _get(self, ref: ClientObjectRef, timeout: float):
|
||||
req = ray_client_pb2.GetRequest(id=ref.id, timeout=timeout)
|
||||
try:
|
||||
data = self.server.GetObject(req, metadata=self.metadata)
|
||||
data = self.data_client.GetObject(req)
|
||||
except grpc.RpcError as e:
|
||||
raise decode_exception(e.details())
|
||||
raise e.details()
|
||||
if not data.valid:
|
||||
raise TaskCancelledError(handle)
|
||||
return cloudpickle.loads(data.data)
|
||||
err = cloudpickle.loads(data.error)
|
||||
logger.error(err)
|
||||
raise err
|
||||
return loads_from_server(data.data)
|
||||
|
||||
def put(self, vals):
|
||||
to_put = []
|
||||
@@ -95,26 +103,37 @@ class Worker:
|
||||
return out
|
||||
|
||||
def _put(self, val):
|
||||
data = cloudpickle.dumps(val)
|
||||
if isinstance(val, ClientObjectRef):
|
||||
raise TypeError(
|
||||
"Calling 'put' on an ObjectRef is not allowed "
|
||||
"(similarly, returning an ObjectRef from a remote "
|
||||
"function is not allowed). If you really want to "
|
||||
"do this, you can wrap the ObjectRef in a list and "
|
||||
"call 'put' on it (or return it).")
|
||||
data = dumps_from_client(val, self._client_id)
|
||||
req = ray_client_pb2.PutRequest(data=data)
|
||||
resp = self.server.PutObject(req, metadata=self.metadata)
|
||||
return ClientObjectRef.from_remote_ref(resp.ref)
|
||||
resp = self.data_client.PutObject(req)
|
||||
return ClientObjectRef(resp.id)
|
||||
|
||||
def wait(self,
|
||||
object_refs: List[ClientObjectRef],
|
||||
*,
|
||||
num_returns: int = 1,
|
||||
timeout: float = None
|
||||
timeout: float = None,
|
||||
fetch_local: bool = True
|
||||
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
|
||||
assert isinstance(object_refs, list)
|
||||
if not isinstance(object_refs, list):
|
||||
raise TypeError("wait() expected a list of ClientObjectRef, "
|
||||
f"got {type(object_refs)}")
|
||||
for ref in object_refs:
|
||||
assert isinstance(ref, ClientObjectRef)
|
||||
if not isinstance(ref, ClientObjectRef):
|
||||
raise TypeError("wait() expected a list of ClientObjectRef, "
|
||||
f"got list containing {type(ref)}")
|
||||
data = {
|
||||
"object_handles": [
|
||||
object_ref.handle for object_ref in object_refs
|
||||
],
|
||||
"object_ids": [object_ref.id for object_ref in object_refs],
|
||||
"num_returns": num_returns,
|
||||
"timeout": timeout if timeout else -1
|
||||
"timeout": timeout if timeout else -1,
|
||||
"client_id": self._client_id,
|
||||
}
|
||||
req = ray_client_pb2.WaitRequest(**data)
|
||||
resp = self.server.WaitObject(req, metadata=self.metadata)
|
||||
@@ -122,41 +141,69 @@ class Worker:
|
||||
# TODO(ameer): improve error/exceptions messages.
|
||||
raise Exception("Client Wait request failed. Reference invalid?")
|
||||
client_ready_object_ids = [
|
||||
ClientObjectRef.from_remote_ref(ref)
|
||||
for ref in resp.ready_object_ids
|
||||
ClientObjectRef(ref) for ref in resp.ready_object_ids
|
||||
]
|
||||
client_remaining_object_ids = [
|
||||
ClientObjectRef.from_remote_ref(ref)
|
||||
for ref in resp.remaining_object_ids
|
||||
ClientObjectRef(ref) for ref in resp.remaining_object_ids
|
||||
]
|
||||
|
||||
return (client_ready_object_ids, client_remaining_object_ids)
|
||||
|
||||
def remote(self, function_or_class, *args, **kwargs):
|
||||
# TODO(barakmich): Arguments to ray.remote
|
||||
# get captured here.
|
||||
if (inspect.isfunction(function_or_class)
|
||||
or is_cython(function_or_class)):
|
||||
return ClientRemoteFunc(function_or_class)
|
||||
elif inspect.isclass(function_or_class):
|
||||
return ClientActorClass(function_or_class)
|
||||
else:
|
||||
raise TypeError("The @ray.remote decorator must be applied to "
|
||||
"either a function or to a class.")
|
||||
|
||||
def call_remote(self, instance, *args, **kwargs):
|
||||
def call_remote(self, instance, *args, **kwargs) -> List[bytes]:
|
||||
task = instance._prepare_client_task()
|
||||
for arg in args:
|
||||
pb_arg = convert_to_arg(arg)
|
||||
pb_arg = convert_to_arg(arg, self._client_id)
|
||||
task.args.append(pb_arg)
|
||||
logging.debug("Scheduling %s" % task)
|
||||
ticket = self.server.Schedule(task, metadata=self.metadata)
|
||||
return ClientObjectRef.from_remote_ref(ticket.return_ref)
|
||||
for k, v in kwargs.items():
|
||||
task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id))
|
||||
return self._call_schedule_for_task(task)
|
||||
|
||||
def _call_schedule_for_task(
        self, task: ray_client_pb2.ClientTask) -> List[bytes]:
    """Send ``task`` to the server's Schedule RPC and return the result ids.

    The task is stamped with this worker's client id before it is sent.

    Raises:
        The exception decoded from the RPC error details when the call
        fails at the gRPC level, or the exception unpickled from
        ``ticket.error`` when the server marks the ticket as invalid.
    """
    logger.debug("Scheduling %s", task)
    task.client_id = self._client_id
    try:
        ticket = self.server.Schedule(task, metadata=self.metadata)
    except grpc.RpcError as e:
        # ``details`` is a method and must be called to obtain the encoded
        # payload; passing the bound method itself cannot be base64-decoded.
        raise decode_exception(e.details())
    if not ticket.valid:
        raise cloudpickle.loads(ticket.error)
    return ticket.return_ids
|
||||
|
||||
def call_release(self, id: bytes) -> None:
    """Drop one local reference to ``id``.

    When the count reaches zero the id is released on the server and its
    counter entry is removed. Does nothing once the worker is closed.
    """
    if self.closed:
        return
    remaining = self.reference_count[id] - 1
    self.reference_count[id] = remaining
    if remaining != 0:
        return
    self._release_server(id)
    del self.reference_count[id]
|
||||
|
||||
def _release_server(self, id: bytes) -> None:
    """Send a release request for ``id`` through the data client, if any."""
    if self.data_client is None:
        return
    logger.debug(f"Releasing {id}")
    request = ray_client_pb2.ReleaseRequest(ids=[id])
    self.data_client.ReleaseObject(request)
|
||||
|
||||
def call_retain(self, id: bytes) -> None:
    """Record one more local reference to ``id``."""
    logger.debug(f"Retaining {id.hex()}")
    self.reference_count[id] = self.reference_count[id] + 1
|
||||
|
||||
def close(self):
|
||||
self.server = None
|
||||
self.log_client.close()
|
||||
self.data_client.close()
|
||||
if self.channel:
|
||||
self.channel.close()
|
||||
self.channel = None
|
||||
self.server = None
|
||||
self.closed = True
|
||||
|
||||
def get_actor(self, name: str) -> ClientActorHandle:
    """Ask the server for the actor called ``name`` and wrap its id.

    A ``NAMED_ACTOR`` task is scheduled; exactly one id is expected back.
    """
    lookup = ray_client_pb2.ClientTask()
    lookup.type = ray_client_pb2.ClientTask.NAMED_ACTOR
    lookup.name = name
    ids = self._call_schedule_for_task(lookup)
    assert len(ids) == 1
    actor_ref = ClientActorRef(ids[0])
    return ClientActorHandle(actor_ref)
|
||||
|
||||
def terminate_actor(self, actor: ClientActorHandle,
|
||||
no_restart: bool) -> None:
|
||||
@@ -164,10 +211,11 @@ class Worker:
|
||||
raise ValueError("ray.kill() only supported for actors. "
|
||||
"Got: {}.".format(type(actor)))
|
||||
term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
|
||||
term_actor.handle = actor.actor_ref.handle
|
||||
term_actor.id = actor.actor_ref.id
|
||||
term_actor.no_restart = no_restart
|
||||
try:
|
||||
term = ray_client_pb2.TerminateRequest(actor=term_actor)
|
||||
term.client_id = self._client_id
|
||||
self.server.Terminate(term)
|
||||
except grpc.RpcError as e:
|
||||
raise decode_exception(e.details())
|
||||
@@ -179,11 +227,12 @@ class Worker:
|
||||
"ray.cancel() only supported for non-actor object refs. "
|
||||
f"Got: {type(obj)}.")
|
||||
term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
|
||||
term_object.handle = obj.handle
|
||||
term_object.id = obj.id
|
||||
term_object.force = force
|
||||
term_object.recursive = recursive
|
||||
try:
|
||||
term = ray_client_pb2.TerminateRequest(task_object=term_object)
|
||||
term.client_id = self._client_id
|
||||
self.server.Terminate(term)
|
||||
except grpc.RpcError as e:
|
||||
raise decode_exception(e.details())
|
||||
@@ -193,7 +242,9 @@ class Worker:
|
||||
req.type = type
|
||||
resp = self.server.ClusterInfo(req)
|
||||
if resp.WhichOneof("response_type") == "resource_table":
|
||||
return resp.resource_table.table
|
||||
# translate from a proto map to a python dict
|
||||
output_dict = {k: v for k, v in resp.resource_table.table.items()}
|
||||
return output_dict
|
||||
return json.loads(resp.json)
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
@@ -201,3 +252,13 @@ class Worker:
|
||||
return self.get_cluster_info(
|
||||
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
|
||||
return False
|
||||
|
||||
|
||||
def make_client_id() -> str:
    """Return a fresh random client id: the hex form of a version-4 UUID."""
    return uuid.uuid4().hex
|
||||
|
||||
|
||||
def decode_exception(data) -> Exception:
    """Base64-decode ``data`` and unpickle it with ``loads_from_server``."""
    return loads_from_server(base64.standard_b64decode(data))
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
import ray
|
||||
|
||||
|
||||
def force_spill_objects(object_refs):
|
||||
"""Force spilling objects to external storage.
|
||||
|
||||
Args:
|
||||
object_refs: Object refs of the objects to be
|
||||
spilled.
|
||||
"""
|
||||
core_worker = ray.worker.global_worker.core_worker
|
||||
# Make sure that the values are object refs.
|
||||
for object_ref in object_refs:
|
||||
if not isinstance(object_ref, ray.ObjectRef):
|
||||
raise TypeError(
|
||||
f"Attempting to call `force_spill_objects` on the "
|
||||
f"value {object_ref}, which is not an ray.ObjectRef.")
|
||||
return core_worker.force_spill_objects(object_refs)
|
||||
@@ -157,12 +157,15 @@ class ExternalStorage(metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
def restore_spilled_objects(self, object_refs: List[ObjectRef],
|
||||
url_with_offset_list: List[str]):
|
||||
url_with_offset_list: List[str]) -> int:
|
||||
"""Restore objects from the external storage.
|
||||
|
||||
Args:
|
||||
object_refs: List of object IDs (note that it is not ref).
|
||||
url_with_offset_list: List of url_with_offset.
|
||||
|
||||
Returns:
|
||||
The total number of bytes restored.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -215,6 +218,7 @@ class FileSystemStorage(ExternalStorage):
|
||||
|
||||
def restore_spilled_objects(self, object_refs: List[ObjectRef],
|
||||
url_with_offset_list: List[str]):
|
||||
total = 0
|
||||
for i in range(len(object_refs)):
|
||||
object_ref = object_refs[i]
|
||||
url_with_offset = url_with_offset_list[i].decode()
|
||||
@@ -228,9 +232,11 @@ class FileSystemStorage(ExternalStorage):
|
||||
metadata_len = int.from_bytes(f.read(8), byteorder="little")
|
||||
buf_len = int.from_bytes(f.read(8), byteorder="little")
|
||||
self._size_check(metadata_len, buf_len, parsed_result.size)
|
||||
total += buf_len
|
||||
metadata = f.read(metadata_len)
|
||||
# read remaining data to our buffer
|
||||
self._put_object_to_store(metadata, buf_len, f, object_ref)
|
||||
return total
|
||||
|
||||
def delete_spilled_objects(self, urls: List[str]):
|
||||
for url in urls:
|
||||
@@ -297,6 +303,7 @@ class ExternalStorageSmartOpenImpl(ExternalStorage):
|
||||
def restore_spilled_objects(self, object_refs: List[ObjectRef],
|
||||
url_with_offset_list: List[str]):
|
||||
from smart_open import open
|
||||
total = 0
|
||||
for i in range(len(object_refs)):
|
||||
object_ref = object_refs[i]
|
||||
url_with_offset = url_with_offset_list[i].decode()
|
||||
@@ -315,9 +322,11 @@ class ExternalStorageSmartOpenImpl(ExternalStorage):
|
||||
metadata_len = int.from_bytes(f.read(8), byteorder="little")
|
||||
buf_len = int.from_bytes(f.read(8), byteorder="little")
|
||||
self._size_check(metadata_len, buf_len, parsed_result.size)
|
||||
total += buf_len
|
||||
metadata = f.read(metadata_len)
|
||||
# read remaining data to our buffer
|
||||
self._put_object_to_store(metadata, buf_len, f, object_ref)
|
||||
return total
|
||||
|
||||
def delete_spilled_objects(self, urls: List[str]):
|
||||
pass
|
||||
@@ -367,8 +376,8 @@ def restore_spilled_objects(object_refs: List[ObjectRef],
|
||||
object_refs: List of object IDs (note that it is not ref).
|
||||
url_with_offset_list: List of url_with_offset.
|
||||
"""
|
||||
_external_storage.restore_spilled_objects(object_refs,
|
||||
url_with_offset_list)
|
||||
return _external_storage.restore_spilled_objects(object_refs,
|
||||
url_with_offset_list)
|
||||
|
||||
|
||||
def delete_spilled_objects(urls: List[str]):
|
||||
|
||||
@@ -12,6 +12,7 @@ import hashlib
|
||||
import cython
|
||||
import inspect
|
||||
import uuid
|
||||
import ray.ray_constants as ray_constants
|
||||
|
||||
|
||||
ctypedef object (*FunctionDescriptor_from_cpp)(const CFunctionDescriptor &)
|
||||
@@ -188,7 +189,8 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
|
||||
function_name = function.__name__
|
||||
class_name = ""
|
||||
|
||||
pickled_function_hash = hashlib.sha1(pickled_function).hexdigest()
|
||||
pickled_function_hash = hashlib.shake_128(pickled_function).hexdigest(
|
||||
ray_constants.ID_SIZE)
|
||||
|
||||
return cls(module_name, function_name, class_name,
|
||||
pickled_function_hash)
|
||||
@@ -208,7 +210,10 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
|
||||
module_name = target_class.__module__
|
||||
class_name = target_class.__name__
|
||||
# Use a random uuid as function hash to solve actor name conflict.
|
||||
return cls(module_name, "__init__", class_name, str(uuid.uuid4()))
|
||||
return cls(
|
||||
module_name, "__init__", class_name,
|
||||
hashlib.shake_128(
|
||||
uuid.uuid4().bytes).hexdigest(ray_constants.ID_SIZE))
|
||||
|
||||
@property
|
||||
def module_name(self):
|
||||
@@ -268,14 +273,14 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
|
||||
Returns:
|
||||
ray.ObjectRef to represent the function descriptor.
|
||||
"""
|
||||
function_id_hash = hashlib.sha1()
|
||||
function_id_hash = hashlib.shake_128()
|
||||
# Include the function module and name in the hash.
|
||||
function_id_hash.update(self.typed_descriptor.ModuleName())
|
||||
function_id_hash.update(self.typed_descriptor.FunctionName())
|
||||
function_id_hash.update(self.typed_descriptor.ClassName())
|
||||
function_id_hash.update(self.typed_descriptor.FunctionHash())
|
||||
# Compute the function ID.
|
||||
function_id = function_id_hash.digest()
|
||||
function_id = function_id_hash.digest(ray_constants.ID_SIZE)
|
||||
return ray.FunctionID(function_id)
|
||||
|
||||
def is_actor_method(self):
|
||||
|
||||
@@ -179,9 +179,10 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
c_bool plasma_objects_only)
|
||||
CRayStatus Contains(const CObjectID &object_id, c_bool *has_object)
|
||||
CRayStatus Wait(const c_vector[CObjectID] &object_ids, int num_objects,
|
||||
int64_t timeout_ms, c_vector[c_bool] *results)
|
||||
int64_t timeout_ms, c_vector[c_bool] *results,
|
||||
c_bool fetch_local)
|
||||
CRayStatus Delete(const c_vector[CObjectID] &object_ids,
|
||||
c_bool local_only, c_bool delete_creating_tasks)
|
||||
c_bool local_only)
|
||||
CRayStatus TriggerGlobalGC()
|
||||
c_string MemoryUsageString()
|
||||
|
||||
@@ -232,7 +233,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
|
||||
(CRayStatus() nogil) check_signals
|
||||
(void() nogil) gc_collect
|
||||
(c_vector[c_string](const c_vector[CObjectID] &) nogil) spill_objects
|
||||
(void(
|
||||
(int64_t(
|
||||
const c_vector[CObjectID] &,
|
||||
const c_vector[c_string] &) nogil) restore_spilled_objects
|
||||
(void(
|
||||
|
||||
@@ -31,7 +31,7 @@ def check_id(b, size=kUniqueIDSize):
|
||||
raise TypeError("Unsupported type: " + str(type(b)))
|
||||
if len(b) != size:
|
||||
raise ValueError("ID string needs to have length " +
|
||||
str(size))
|
||||
str(size) + ", got " + str(len(b)))
|
||||
|
||||
|
||||
cdef extern from "ray/common/constants.h" nogil:
|
||||
|
||||
@@ -37,7 +37,7 @@ def memory_summary():
|
||||
return reply.memory_summary
|
||||
|
||||
|
||||
def free(object_refs, local_only=False, delete_creating_tasks=False):
|
||||
def free(object_refs, local_only=False):
|
||||
"""Free a list of IDs from the in-process and plasma object stores.
|
||||
|
||||
This function is a low-level API which should be used in restricted
|
||||
@@ -59,8 +59,6 @@ def free(object_refs, local_only=False, delete_creating_tasks=False):
|
||||
object_refs (List[ObjectRef]): List of object refs to delete.
|
||||
local_only (bool): Whether only deleting the list of objects in local
|
||||
object store or all object stores.
|
||||
delete_creating_tasks (bool): Whether also delete the object creating
|
||||
tasks.
|
||||
"""
|
||||
worker = ray.worker.global_worker
|
||||
|
||||
@@ -83,5 +81,4 @@ def free(object_refs, local_only=False, delete_creating_tasks=False):
|
||||
if len(object_refs) == 0:
|
||||
return
|
||||
|
||||
worker.core_worker.free_objects(object_refs, local_only,
|
||||
delete_creating_tasks)
|
||||
worker.core_worker.free_objects(object_refs, local_only)
|
||||
|
||||
@@ -22,7 +22,7 @@ from ray.ray_logging import setup_component_logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# The groups are worker id, job id, and pid.
|
||||
JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]{40})-(\d+)-(\d+)")
|
||||
JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]+)-(\d+)-(\d+)")
|
||||
|
||||
|
||||
class LogFileInfo:
|
||||
|
||||
+19
-15
@@ -15,11 +15,14 @@ from ray.autoscaler._private.constants import AUTOSCALER_UPDATE_INTERVAL_S
|
||||
from ray.autoscaler._private.load_metrics import LoadMetrics
|
||||
from ray.autoscaler._private.constants import \
|
||||
AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE
|
||||
from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS
|
||||
import ray.gcs_utils
|
||||
import ray.utils
|
||||
import ray.ray_constants as ray_constants
|
||||
from ray.ray_logging import setup_component_logger
|
||||
from ray._raylet import GlobalStateAccessor
|
||||
from ray.experimental.internal_kv import _internal_kv_put, \
|
||||
_internal_kv_initialized
|
||||
|
||||
import redis
|
||||
|
||||
@@ -65,11 +68,7 @@ def parse_resource_demands(resource_load_by_shape):
|
||||
except Exception:
|
||||
logger.exception("Failed to parse resource demands.")
|
||||
|
||||
# Bound the total number of bundles to 2xMAX_RESOURCE_DEMAND_VECTOR_SIZE.
|
||||
# This guarantees the resource demand scheduler bin packing algorithm takes
|
||||
# a reasonable amount of time to run.
|
||||
return waiting_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE], \
|
||||
infeasible_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE]
|
||||
return waiting_bundles, infeasible_bundles
|
||||
|
||||
|
||||
class Monitor:
|
||||
@@ -184,14 +183,8 @@ class Monitor:
|
||||
data: a resource request as JSON, e.g. {"CPU": 1}
|
||||
"""
|
||||
|
||||
if not self.autoscaler:
|
||||
return
|
||||
|
||||
try:
|
||||
self.autoscaler.request_resources(json.loads(data))
|
||||
except Exception:
|
||||
# We don't want this to kill the monitor.
|
||||
traceback.print_exc()
|
||||
resource_request = json.loads(data)
|
||||
self.load_metrics.set_resource_requests(resource_request)
|
||||
|
||||
def process_messages(self, max_messages=10000):
|
||||
"""Process all messages ready in the subscription channels.
|
||||
@@ -257,12 +250,23 @@ class Monitor:
|
||||
|
||||
# Handle messages from the subscription channels.
|
||||
while True:
|
||||
self.update_raylet_map()
|
||||
self.update_load_metrics()
|
||||
status = {
|
||||
"load_metrics_report": self.load_metrics.summary()._asdict()
|
||||
}
|
||||
|
||||
# Process autoscaling actions
|
||||
if self.autoscaler:
|
||||
# Only used to update the load metrics for the autoscaler.
|
||||
self.update_raylet_map()
|
||||
self.update_load_metrics()
|
||||
self.autoscaler.update()
|
||||
status[
|
||||
"autoscaler_report"] = self.autoscaler.summary()._asdict()
|
||||
|
||||
as_json = json.dumps(status)
|
||||
if _internal_kv_initialized():
|
||||
_internal_kv_put(
|
||||
DEBUG_AUTOSCALING_STATUS, as_json, overwrite=True)
|
||||
|
||||
# Process a round of messages.
|
||||
self.process_messages()
|
||||
|
||||
@@ -19,7 +19,7 @@ def env_bool(key, default):
|
||||
return default
|
||||
|
||||
|
||||
ID_SIZE = 20
|
||||
ID_SIZE = 28
|
||||
|
||||
# The default maximum number of bytes to allocate to the object store unless
|
||||
# overridden by the user.
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
from typing import Callable
|
||||
|
||||
import ray
|
||||
from ray.utils import binary_to_hex
|
||||
|
||||
@@ -258,3 +261,27 @@ def setup_and_get_worker_interceptor_logger(args,
|
||||
# logger to add a newline at the end of string.
|
||||
handler.terminator = ""
|
||||
return logger
|
||||
|
||||
|
||||
class WorkerStandardStreamDispatcher:
    """Fan out worker standard-stream data to named handler callables.

    Handlers are stored as ``(name, handler)`` pairs in registration order.
    Every method acquires the internal lock while it touches the list.
    """

    def __init__(self):
        self.handlers = []
        self._lock = threading.Lock()

    def add_handler(self, name: str, handler: Callable) -> None:
        """Register ``handler`` under ``name``.

        No uniqueness check is made; the pair is simply appended.
        """
        with self._lock:
            self.handlers.append((name, handler))

    def remove_handler(self, name: str) -> None:
        """Drop every handler that was registered under ``name``."""
        with self._lock:
            self.handlers = [
                entry for entry in self.handlers if entry[0] != name
            ]

    def emit(self, data):
        """Call each registered handler with ``data``, in registration order.

        The lock stays held while the handlers run.
        """
        with self._lock:
            for _name, callback in self.handlers:
                callback(data)
|
||||
|
||||
|
||||
global_worker_stdstream_dispatcher = WorkerStandardStreamDispatcher()
|
||||
|
||||
+6
-26
@@ -2,17 +2,15 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from ray._private.ray_microbenchmark_helpers import timeit
|
||||
from ray._private.ray_client_microbenchmark import (main as
|
||||
client_microbenchmark_main)
|
||||
import numpy as np
|
||||
import multiprocessing
|
||||
import ray
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Only run tests matching this filter pattern.
|
||||
filter_pattern = os.environ.get("TESTS_TO_RUN", "")
|
||||
|
||||
|
||||
@ray.remote(num_cpus=0)
|
||||
class Actor:
|
||||
@@ -71,27 +69,6 @@ def small_value_batch(n):
|
||||
return 0
|
||||
|
||||
|
||||
def timeit(name, fn, multiplier=1):
    """Benchmark ``fn`` and print its throughput.

    Nothing is run when the module-level ``filter_pattern`` is not a
    substring of ``name``. Otherwise ``fn`` is called repeatedly for about
    one second as warmup, followed by four timed rounds of about two seconds
    each. The mean and standard deviation of
    ``multiplier * calls / elapsed`` over the rounds are printed.

    Args:
        name: Label printed with the result and matched against the filter.
        fn: Zero-argument callable to benchmark.
        multiplier: Factor applied to the per-round call rate.
    """
    if filter_pattern not in name:
        return

    def _spin(duration):
        # Call fn until `duration` seconds have passed; report the number of
        # calls made and the timestamp at which the spin began.
        began = time.time()
        calls = 0
        while time.time() - began < duration:
            fn()
            calls += 1
        return calls, began

    # warmup
    _spin(1)

    # real run
    rates = []
    for _ in range(4):
        calls, began = _spin(2)
        finished = time.time()
        rates.append(multiplier * calls / (finished - began))

    print(name, "per second", round(np.mean(rates), 2), "+-",
          round(np.std(rates), 2))
|
||||
|
||||
|
||||
def check_optimized_build():
|
||||
if not ray._raylet.OPTIMIZED:
|
||||
msg = ("WARNING: Unoptimized build! "
|
||||
@@ -277,6 +254,9 @@ def main():
|
||||
ray.get([async_actor_work.remote(a) for _ in range(m)])
|
||||
|
||||
timeit("n:n async-actor calls async", async_actor_multi, m * n)
|
||||
ray.shutdown()
|
||||
|
||||
client_microbenchmark_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from telnetlib import Telnet
|
||||
import time
|
||||
import urllib
|
||||
import urllib.parse
|
||||
@@ -172,8 +171,7 @@ def continue_debug_session():
|
||||
ray.experimental.internal_kv._internal_kv_del(key)
|
||||
return
|
||||
host, port = session["pdb_address"].split(":")
|
||||
with Telnet(host, int(port)) as tn:
|
||||
tn.interact()
|
||||
ray.util.rpdb.connect_pdb_client(host, int(port))
|
||||
ray.experimental.internal_kv._internal_kv_del(key)
|
||||
continue_debug_session()
|
||||
return
|
||||
@@ -215,8 +213,7 @@ def debug(address):
|
||||
ray.experimental.internal_kv._internal_kv_get(
|
||||
active_sessions[index]))
|
||||
host, port = session["pdb_address"].split(":")
|
||||
with Telnet(host, int(port)) as tn:
|
||||
tn.interact()
|
||||
ray.util.rpdb.connect_pdb_client(host, int(port))
|
||||
|
||||
|
||||
@cli.command()
|
||||
|
||||
@@ -74,7 +74,8 @@ def _try_to_compute_deterministic_class_id(cls, depth=5):
|
||||
new_class_id = pickle.dumps(pickle.loads(class_id))
|
||||
if new_class_id == class_id:
|
||||
# We appear to have reached a fix point, so use this as the ID.
|
||||
return hashlib.sha1(new_class_id).digest()
|
||||
return hashlib.shake_128(new_class_id).digest(
|
||||
ray_constants.ID_SIZE)
|
||||
class_id = new_class_id
|
||||
|
||||
# We have not reached a fixed point, so we may end up with a different
|
||||
@@ -82,7 +83,7 @@ def _try_to_compute_deterministic_class_id(cls, depth=5):
|
||||
# same class definition being exported many many times.
|
||||
logger.warning(
|
||||
f"WARNING: Could not produce a deterministic class ID for class {cls}")
|
||||
return hashlib.sha1(new_class_id).digest()
|
||||
return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE)
|
||||
|
||||
|
||||
def object_ref_deserializer(reduced_obj_ref, owner_address):
|
||||
|
||||
@@ -119,6 +119,14 @@ py_test(
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "test_imported_backend",
|
||||
size = "small",
|
||||
srcs = serve_tests_srcs,
|
||||
tags = ["exclusive"],
|
||||
deps = [":serve_lib"],
|
||||
)
|
||||
|
||||
|
||||
# Runs test_api and test_failure with injected failures in the controller.
|
||||
# TODO(simon): Tests are disabled until #11683 is fixed.
|
||||
|
||||
+93
-28
@@ -4,22 +4,41 @@ import time
|
||||
from functools import wraps
|
||||
import os
|
||||
from uuid import UUID
|
||||
import threading
|
||||
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
|
||||
|
||||
from ray.serve.context import TaskContext
|
||||
|
||||
import ray
|
||||
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
|
||||
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
|
||||
from ray.serve.controller import ServeController
|
||||
from ray.serve.handle import RayServeHandle
|
||||
from ray.serve.handle import RayServeHandle, RayServeSyncHandle
|
||||
from ray.serve.utils import (block_until_http_ready, format_actor_name,
|
||||
get_random_letters, logger, get_conda_env_dir)
|
||||
from ray.serve.exceptions import RayServeException
|
||||
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
|
||||
from ray.serve.config import (BackendConfig, ReplicaConfig, BackendMetadata,
|
||||
HTTPConfig)
|
||||
from ray.serve.env import CondaEnv
|
||||
from ray.serve.router import RequestMetadata, Router
|
||||
from ray.actor import ActorHandle
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union
|
||||
|
||||
_INTERNAL_CONTROLLER_NAME = None
|
||||
|
||||
# Event loop shared by all callers; created on first use by
# create_or_get_async_loop_in_thread().
global_async_loop = None


def create_or_get_async_loop_in_thread():
    """Return the shared event loop, starting it in a daemon thread if needed.

    On the first call a new asyncio event loop is created, stored in the
    module-level ``global_async_loop`` and driven by ``run_forever`` on a
    daemon thread. Later calls return that same loop.
    """
    global global_async_loop
    if global_async_loop is not None:
        return global_async_loop

    global_async_loop = asyncio.new_event_loop()
    threading.Thread(
        target=global_async_loop.run_forever,
        daemon=True,
    ).start()
    return global_async_loop
|
||||
|
||||
|
||||
def _set_internal_controller_name(name):
|
||||
global _INTERNAL_CONTROLLER_NAME
|
||||
@@ -36,6 +55,36 @@ def _ensure_connected(f: Callable) -> Callable:
|
||||
return check
|
||||
|
||||
|
||||
class ThreadProxiedRouter:
    """Own a ``Router`` and schedule its setup on the appropriate event loop.

    In sync mode the router's setup coroutine is submitted to the shared
    loop returned by ``create_or_get_async_loop_in_thread()``. Otherwise it
    is scheduled as a task on the loop returned by
    ``asyncio.get_event_loop()``.
    """

    def __init__(self, controller_handle, sync: bool):
        self.router = Router(controller_handle)

        if not sync:
            self.async_loop = asyncio.get_event_loop()
            self.async_loop.create_task(self.router.setup_in_async_loop())
            return

        self.async_loop = create_or_get_async_loop_in_thread()
        # The returned future is not awaited here; setup completes on the
        # background loop.
        asyncio.run_coroutine_threadsafe(
            self.router.setup_in_async_loop(),
            self.async_loop,
        )

    def _remote(self, endpoint_name, handle_options, request_data,
                kwargs) -> Coroutine:
        """Build request metadata and return the router's assignment coroutine.

        Args:
            endpoint_name: Endpoint the request is addressed to.
            handle_options: Object supplying ``method_name``, ``shard_key``,
                ``http_method`` and ``http_headers``.
            request_data: Payload passed through to the router.
            kwargs: Dict expanded as keyword arguments to
                ``assign_request``.

        Returns:
            The coroutine produced by ``self.router.assign_request``.
        """
        metadata = RequestMetadata(
            get_random_letters(10),  # Used for debugging.
            endpoint_name,
            TaskContext.Python,
            call_method=handle_options.method_name,
            shard_key=handle_options.shard_key,
            http_method=handle_options.http_method,
            http_headers=handle_options.http_headers,
        )
        return self.router.assign_request(metadata, request_data, **kwargs)
|
||||
|
||||
|
||||
class Client:
|
||||
def __init__(self,
|
||||
controller: ActorHandle,
|
||||
@@ -45,15 +94,10 @@ class Client:
|
||||
self._controller_name = controller_name
|
||||
self._detached = detached
|
||||
self._shutdown = False
|
||||
self._http_host, self._http_port = ray.get(
|
||||
controller.get_http_config.remote())
|
||||
self._http_config = ray.get(controller.get_http_config.remote())
|
||||
|
||||
# NOTE(simon): Used to cache client.get_handle(endpoint) call. It will
|
||||
# mostly grow in size, it will only shrink when user calls the
|
||||
# .remove_endpoint method. This is fine because we expect the number of
|
||||
# endpoints to be fairly small. However, in case this dictionary does
|
||||
# grow very big, we can replace it with a LRU cache instead.
|
||||
self._handle_cache: Dict[str, ActorHandle] = dict()
|
||||
self._sync_proxied_router = None
|
||||
self._async_proxied_router = None
|
||||
|
||||
# NOTE(edoakes): Need this because the shutdown order isn't guaranteed
|
||||
# when the interpreter is exiting so we can't rely on __del__ (it
|
||||
@@ -65,6 +109,18 @@ class Client:
|
||||
|
||||
atexit.register(shutdown_serve_client)
|
||||
|
||||
def _get_proxied_router(self, sync: bool):
    """Return the cached ``ThreadProxiedRouter`` for the requested mode.

    One router is kept per mode (sync and async); each is created from
    ``self._controller`` the first time that mode is requested.
    """
    if sync:
        router = self._sync_proxied_router
        if router is None:
            router = ThreadProxiedRouter(self._controller, sync=True)
            self._sync_proxied_router = router
        return router

    router = self._async_proxied_router
    if router is None:
        router = ThreadProxiedRouter(self._controller, sync=False)
        self._async_proxied_router = router
    return router
|
||||
|
||||
def __del__(self):
|
||||
if not self._detached:
|
||||
logger.debug("Shutting down Ray Serve because client went out of "
|
||||
@@ -181,8 +237,8 @@ class Client:
|
||||
num_cpus=0, resources={
|
||||
node_id: 0.01
|
||||
}).remote(
|
||||
"http://{}:{}/-/routes".format(self._http_host,
|
||||
self._http_port),
|
||||
"http://{}:{}/-/routes".format(self._http_config.host,
|
||||
self._http_config.port),
|
||||
check_ready=check_ready,
|
||||
timeout=HTTP_PROXY_TIMEOUT)
|
||||
futures.append(future)
|
||||
@@ -198,8 +254,6 @@ class Client:
|
||||
|
||||
Does not delete any associated backends.
|
||||
"""
|
||||
if endpoint in self._handle_cache:
|
||||
del self._handle_cache[endpoint]
|
||||
self._get_result(self._controller.delete_endpoint.remote(endpoint))
|
||||
|
||||
@_ensure_connected
|
||||
@@ -410,10 +464,11 @@ class Client:
|
||||
proportion))
|
||||
|
||||
@_ensure_connected
|
||||
def get_handle(self,
|
||||
endpoint_name: str,
|
||||
missing_ok: Optional[bool] = False,
|
||||
sync: bool = True) -> RayServeHandle:
|
||||
def get_handle(
|
||||
self,
|
||||
endpoint_name: str,
|
||||
missing_ok: Optional[bool] = False,
|
||||
sync: bool = True) -> Union[RayServeHandle, RayServeSyncHandle]:
|
||||
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
|
||||
|
||||
Args:
|
||||
@@ -433,14 +488,26 @@ class Client:
|
||||
|
||||
if asyncio.get_event_loop().is_running() and sync:
|
||||
logger.warning(
|
||||
"You are retrieving a ServeHandle inside an asyncio loop. "
|
||||
"You are retrieving a sync handle inside an asyncio loop. "
|
||||
"Try getting client.get_handle(.., sync=False) to get better "
|
||||
"performance.")
|
||||
"performance. Learn more at https://docs.ray.io/en/master/"
|
||||
"serve/advanced.html#sync-and-async-handles")
|
||||
|
||||
if endpoint_name not in self._handle_cache:
|
||||
handle = RayServeHandle(self._controller, endpoint_name, sync=sync)
|
||||
self._handle_cache[endpoint_name] = handle
|
||||
return self._handle_cache[endpoint_name]
|
||||
if not asyncio.get_event_loop().is_running() and not sync:
|
||||
logger.warning(
|
||||
"You are retrieving an async handle outside an asyncio loop. "
|
||||
"You should make sure client.get_handle is called inside a "
|
||||
"running event loop. Or call client.get_handle(.., sync=True) "
|
||||
"to create sync handle. Learn more at https://docs.ray.io/en/"
|
||||
"master/serve/advanced.html#sync-and-async-handles")
|
||||
|
||||
if sync:
|
||||
handle = RayServeSyncHandle(
|
||||
self._get_proxied_router(sync=sync), endpoint_name)
|
||||
else:
|
||||
handle = RayServeHandle(
|
||||
self._get_proxied_router(sync=sync), endpoint_name)
|
||||
return handle
|
||||
|
||||
|
||||
def start(detached: bool = False,
|
||||
@@ -492,9 +559,7 @@ def start(detached: bool = False,
|
||||
max_task_retries=-1,
|
||||
).remote(
|
||||
controller_name,
|
||||
http_host,
|
||||
http_port,
|
||||
http_middlewares,
|
||||
HTTPConfig(http_host, http_port, http_middlewares),
|
||||
detached=detached)
|
||||
|
||||
if http_host is not None:
|
||||
|
||||
@@ -186,10 +186,10 @@ class RayServeReplica:
|
||||
"backend_replica_starts",
|
||||
description=("The number of time this replica "
|
||||
"has been restarted due to failure."),
|
||||
tag_keys=("backend", "replica_tag"))
|
||||
tag_keys=("backend", "replica"))
|
||||
self.restart_counter.set_default_tags({
|
||||
"backend": self.backend_tag,
|
||||
"replica_tag": self.replica_tag
|
||||
"replica": self.replica_tag
|
||||
})
|
||||
|
||||
self.queuing_latency_tracker = metrics.Histogram(
|
||||
@@ -198,39 +198,39 @@ class RayServeReplica:
|
||||
"The latency for queries waiting in the replica's queue "
|
||||
"waiting to be processed or batched."),
|
||||
boundaries=DEFAULT_LATENCY_BUCKET_MS,
|
||||
tag_keys=("backend", "replica_tag"))
|
||||
tag_keys=("backend", "replica"))
|
||||
self.queuing_latency_tracker.set_default_tags({
|
||||
"backend": self.backend_tag,
|
||||
"replica_tag": self.replica_tag
|
||||
"replica": self.replica_tag
|
||||
})
|
||||
|
||||
self.processing_latency_tracker = metrics.Histogram(
|
||||
"backend_processing_latency_ms",
|
||||
description="The latency for queries to be processed",
|
||||
boundaries=DEFAULT_LATENCY_BUCKET_MS,
|
||||
tag_keys=("backend", "replica_tag", "batch_size"))
|
||||
tag_keys=("backend", "replica", "batch_size"))
|
||||
self.processing_latency_tracker.set_default_tags({
|
||||
"backend": self.backend_tag,
|
||||
"replica_tag": self.replica_tag
|
||||
"replica": self.replica_tag
|
||||
})
|
||||
|
||||
self.num_queued_items = metrics.Gauge(
|
||||
"replica_queued_queries",
|
||||
description=("Current number of queries queued in the "
|
||||
"the backend replicas"),
|
||||
tag_keys=("backend", "replica_tag"))
|
||||
tag_keys=("backend", "replica"))
|
||||
self.num_queued_items.set_default_tags({
|
||||
"backend": self.backend_tag,
|
||||
"replica_tag": self.replica_tag
|
||||
"replica": self.replica_tag
|
||||
})
|
||||
|
||||
self.num_processing_items = metrics.Gauge(
|
||||
"replica_processing_queries",
|
||||
description="Current number of queries being processed",
|
||||
tag_keys=("backend", "replica_tag"))
|
||||
tag_keys=("backend", "replica"))
|
||||
self.num_processing_items.set_default_tags({
|
||||
"backend": self.backend_tag,
|
||||
"replica_tag": self.replica_tag
|
||||
"replica": self.replica_tag
|
||||
})
|
||||
|
||||
self.restart_counter.record(1)
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
from ray.serve.utils import import_class
|
||||
|
||||
|
||||
class ImportedBackend:
    """Factory for a class that will dynamically import a backend class.

    This is intended to be used when the source code for a backend is
    installed in the worker environment but not the driver.

    Intended usage:
        >>> client = serve.connect()
        >>> client.create_backend("b", ImportedBackend("module.Class"), *args)

    This will import module.Class on the worker and proxy all relevant methods
    to it.
    """

    def __new__(cls, class_path):
        """Return a new proxy class (not an instance) bound to ``class_path``.

        Args:
            class_path: Import path handed to ``import_class`` when the
                returned class is instantiated.
        """

        class ImportedBackend:
            def __init__(self, *args, **kwargs):
                # The import is deferred until instantiation.
                self.wrapped = import_class(class_path)(*args, **kwargs)

            def reconfigure(self, *args, **kwargs):
                # NOTE(edoakes): we check that the reconfigure method is
                # present if the user specifies a user_config, so we need to
                # proxy it manually.
                return self.wrapped.reconfigure(*args, **kwargs)

            def __getattr__(self, attr):
                """Proxy all other methods to the wrapped class."""
                # __getattr__ only runs when normal lookup fails. If
                # `wrapped` itself is missing (instance built without
                # __init__, e.g. by copy/pickle), reading self.wrapped below
                # would re-enter this method forever. Fail with the
                # AttributeError that hasattr()/getattr() expect instead.
                if attr == "wrapped":
                    raise AttributeError(attr)
                return getattr(self.wrapped, attr)

        return ImportedBackend
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user