Merge branch 'master' into py39

This commit is contained in:
Akash Patel
2020-12-24 13:13:30 -05:00
committed by GitHub
407 changed files with 14044 additions and 12568 deletions
+2 -46
View File
@@ -20,50 +20,6 @@ before_install:
matrix:
include:
- os: linux
env:
- PYTHON=3.6 SMALL_AND_LARGE_TESTS=1 RAY_ENABLE_NEW_SCHEDULER=1
- PYTHONWARNINGS=ignore
- RAY_DEFAULT_BUILD=1
- RAY_CYTHON_EXAMPLES=1
- RAY_USE_RANDOM_PORTS=1
install:
- . ./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED,RAY_CI_DASHBOARD_AFFECTED
before_script:
- . ./ci/travis/ci.sh build
script:
# bazel python tests. This should be run last to keep its logs at the end of travis logs.
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,-medium_size_python_tests_a_to_j,-medium_size_python_tests_k_to_z,-new_scheduler_broken python/ray/tests/...; fi
- os: linux
env:
- PYTHON=3.6 MEDIUM_TESTS_A_TO_J=1 RAY_ENABLE_NEW_SCHEDULER=1
- PYTHONWARNINGS=ignore
- RAY_DEFAULT_BUILD=1
- RAY_CYTHON_EXAMPLES=1
- RAY_USE_RANDOM_PORTS=1
install:
- . ./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED,RAY_CI_DASHBOARD_AFFECTED
before_script:
- . ./ci/travis/ci.sh build
script:
# bazel python tests for medium size tests. Used for parallelization.
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_a_to_j,-new_scheduler_broken python/ray/tests/...; fi
- os: linux
env:
- PYTHON=3.6 MEDIUM_TESTS_K_TO_Z=1 RAY_ENABLE_NEW_SCHEDULER=1
- PYTHONWARNINGS=ignore
- RAY_DEFAULT_BUILD=1
- RAY_CYTHON_EXAMPLES=1
- RAY_USE_RANDOM_PORTS=1
install:
- . ./ci/travis/ci.sh init RAY_CI_SERVE_AFFECTED,RAY_CI_TUNE_AFFECTED,RAY_CI_PYTHON_AFFECTED,RAY_CI_DASHBOARD_AFFECTED
before_script:
- . ./ci/travis/ci.sh build
script:
# bazel python tests for medium size tests. Used for parallelization.
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_k_to_z,-new_scheduler_broken python/ray/tests/...; fi
- os: linux
env:
- PYTHON=3.6 SMALL_AND_LARGE_TESTS=1
@@ -90,6 +46,7 @@ matrix:
script:
# bazel python tests for medium size tests. Used for parallelization.
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,medium_size_python_tests_a_to_j python/ray/tests/...; fi
- if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only,client_tests --test_env=RAY_CLIENT_MODE=1 python/ray/tests/...; fi
- os: linux
env:
@@ -219,7 +176,7 @@ matrix:
- . ./ci/travis/ci.sh init RAY_CI_MACOS_WHEELS_AFFECTED,RAY_CI_JAVA_AFFECTED,RAY_CI_STREAMING_JAVA_AFFECTED
before_script:
- brew tap adoptopenjdk/openjdk
- brew cask install adoptopenjdk8
- brew install --cask adoptopenjdk8
- export JAVA_HOME=/Library/Java/JavaVirtualMachines/adoptopenjdk-8.jdk/Contents/Home
- java -version
- . ./ci/travis/ci.sh build
@@ -577,4 +534,3 @@ deploy:
repo: ray-project/ray
branch: master
condition: $MULTIPLATFORM_JARS = 1 || $MAC_JARS = 1 || $LINUX_JARS = 1
+3 -109
View File
@@ -880,7 +880,7 @@ cc_test(
)
cc_test(
name = "local_placement_group_manager_test",
name = "placement_group_resource_manager_test",
srcs = ["src/ray/raylet/placement_group_resource_manager_test.cc"],
copts = COPTS,
deps = [
@@ -956,8 +956,8 @@ cc_test(
)
cc_test(
name = "task_dependency_manager_test",
srcs = ["src/ray/raylet/task_dependency_manager_test.cc"],
name = "dependency_manager_test",
srcs = ["src/ray/raylet/dependency_manager_test.cc"],
copts = COPTS,
deps = [
":raylet_lib",
@@ -1020,7 +1020,6 @@ cc_test(
cc_library(
name = "gcs_test_util_lib",
hdrs = [
"src/ray/gcs/test/accessor_test_base.h",
"src/ray/gcs/test/gcs_test_util.h",
],
copts = COPTS,
@@ -1621,111 +1620,6 @@ cc_library(
],
)
# TODO(micafan) Support test group in future. Use test group we can run all gcs test once.
cc_test(
name = "redis_gcs_client_test",
srcs = ["src/ray/gcs/test/redis_gcs_client_test.cc"],
args = [
"$(location redis-server)",
"$(location redis-cli)",
"$(location libray_redis_module.so)",
],
copts = COPTS,
data = [
"//:libray_redis_module.so",
"//:redis-cli",
"//:redis-server",
],
deps = [
":gcs",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "redis_actor_info_accessor_test",
srcs = ["src/ray/gcs/test/redis_actor_info_accessor_test.cc"],
args = [
"$(location redis-server)",
"$(location redis-cli)",
"$(location libray_redis_module.so)",
],
copts = COPTS,
data = [
"//:libray_redis_module.so",
"//:redis-cli",
"//:redis-server",
],
deps = [
":gcs",
":gcs_test_util_lib",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "redis_object_info_accessor_test",
srcs = ["src/ray/gcs/test/redis_object_info_accessor_test.cc"],
args = [
"$(location redis-server)",
"$(location redis-cli)",
"$(location libray_redis_module.so)",
],
copts = COPTS,
data = [
"//:libray_redis_module.so",
"//:redis-cli",
"//:redis-server",
],
deps = [
":gcs",
":gcs_test_util_lib",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "redis_job_info_accessor_test",
srcs = ["src/ray/gcs/test/redis_job_info_accessor_test.cc"],
args = [
"$(location redis-server)",
"$(location redis-cli)",
"$(location libray_redis_module.so)",
],
copts = COPTS,
data = [
"//:libray_redis_module.so",
"//:redis-cli",
"//:redis-server",
],
deps = [
":gcs",
":gcs_test_util_lib",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "redis_node_info_accessor_test",
srcs = ["src/ray/gcs/test/redis_node_info_accessor_test.cc"],
args = [
"$(location redis-server)",
"$(location redis-cli)",
"$(location libray_redis_module.so)",
],
copts = COPTS,
data = [
"//:libray_redis_module.so",
"//:redis-cli",
"//:redis-server",
],
deps = [
":gcs",
":gcs_test_util_lib",
"@com_google_googletest//:gtest_main",
],
)
cc_test(
name = "asio_test",
srcs = ["src/ray/gcs/test/asio_test.cc"],
+4 -2
View File
@@ -1,12 +1,14 @@
# py_test_module_list creates a py_test target for each
# Python file in `files`
def py_test_module_list(files, size, deps, extra_srcs, **kwargs):
def py_test_module_list(files, size, deps, extra_srcs, name_suffix="", **kwargs):
for file in files:
# remove .py
name = file[:-3]
name = file[:-3] + name_suffix
main = file
native.py_test(
name = name,
size = size,
main = file,
srcs = extra_srcs + [file],
**kwargs
)
+9 -2
View File
@@ -120,7 +120,6 @@ test_core() {
case "${OSTYPE}" in
msys)
args+=(
-//:redis_gcs_client_test
-//:core_worker_test
-//:event_test
-//:gcs_pub_sub_test
@@ -262,6 +261,11 @@ _bazel_build_before_install() {
bazel build "${target}"
}
# Build the //:install_py_proto Bazel target.
# build() calls this instead of _bazel_build_before_install when LINT is 1.
_bazel_build_protobuf() {
bazel build "//:install_py_proto"
}
install_ray() {
# TODO(mehrdadn): This function should be unified with the one in python/build-wheel-windows.sh.
(
@@ -296,7 +300,8 @@ build_wheels() {
;;
darwin*)
# This command should be kept in sync with ray/python/README-building-wheels.md.
suppress_output "${WORKSPACE_DIR}"/python/build-wheel-macos.sh
# Remove suppress_output for now to avoid timeout
"${WORKSPACE_DIR}"/python/build-wheel-macos.sh
;;
msys*)
keep_alive "${WORKSPACE_DIR}"/python/build-wheel-windows.sh
@@ -457,6 +462,8 @@ init() {
build() {
if [ "${LINT-}" != 1 ]; then
_bazel_build_before_install
else
_bazel_build_protobuf
fi
if ! need_wheels; then
+12
View File
@@ -16,6 +16,7 @@ template <typename T>
class ObjectRef {
public:
ObjectRef();
~ObjectRef();
ObjectRef(const ObjectID &id);
@@ -46,6 +47,17 @@ ObjectRef<T>::ObjectRef() {}
// Wraps an existing object ID. If a core worker process is already
// initialized, a local reference for the ID is registered with it; the
// destructor issues the matching removal.
// NOTE(review): only this constructor adds a reference, while the destructor
// runs for every ObjectRef -- confirm that default-constructed and copied
// instances keep the count balanced.
template <typename T>
ObjectRef<T>::ObjectRef(const ObjectID &id) {
  id_ = id;
  if (!CoreWorkerProcess::IsInitialized()) {
    // No core worker to register with; the ID is stored untracked.
    return;
  }
  CoreWorkerProcess::GetCoreWorker().AddLocalReference(id_);
}
// Drops the local reference for id_ when a core worker process is
// initialized at destruction time; otherwise does nothing.
// NOTE(review): the removal is issued regardless of how this instance was
// constructed -- confirm it always pairs with an earlier AddLocalReference.
template <typename T>
ObjectRef<T>::~ObjectRef() {
  if (!CoreWorkerProcess::IsInitialized()) {
    return;
  }
  CoreWorkerProcess::GetCoreWorker().RemoveLocalReference(id_);
}
template <typename T>
+20 -26
View File
@@ -91,16 +91,13 @@ TEST(RayClusterModeTest, FullTest) {
auto r5 = Ray::Task(Plus, r4, r3).Remote();
auto r6 = Ray::Task(Plus, r4, 10).Remote();
///// TODO(ameer/guyang): All the commented code lines below should be
///// uncommented once reference counting is added. Currently the objects
///// are leaking from the object store.
int result5 = *(Ray::Get(r5));
// int result4 = *(Ray::Get(r4));
int result4 = *(Ray::Get(r4));
int result6 = *(Ray::Get(r6));
// int result3 = *(Ray::Get(r3));
int result3 = *(Ray::Get(r3));
EXPECT_EQ(result0, 1);
// EXPECT_EQ(result3, 1);
// EXPECT_EQ(result4, 2);
EXPECT_EQ(result3, 1);
EXPECT_EQ(result4, 2);
EXPECT_EQ(result5, 3);
EXPECT_EQ(result6, 12);
@@ -114,37 +111,34 @@ TEST(RayClusterModeTest, FullTest) {
int result7 = *(Ray::Get(r7));
int result8 = *(Ray::Get(r8));
int result9 = *(Ray::Get(r9));
// int result10 = *(Ray::Get(r10));
int result10 = *(Ray::Get(r10));
EXPECT_EQ(result7, 15);
EXPECT_EQ(result8, 16);
EXPECT_EQ(result9, 19);
// EXPECT_EQ(result10, 27);
EXPECT_EQ(result10, 27);
/// create actor and task function remote call with args passed by reference
// ActorHandle<Counter> actor5 = Ray::Actor(Counter::FactoryCreate, r10, 0).Remote();
ActorHandle<Counter> actor5 = Ray::Actor(Counter::FactoryCreate, 27, 0).Remote();
// auto r11 = actor5.Task(&Counter::Add, r0).Remote();
auto r11 = actor5.Task(&Counter::Add, 1).Remote();
// auto r12 = actor5.Task(&Counter::Add, r11).Remote();
ActorHandle<Counter> actor5 = Ray::Actor(Counter::FactoryCreate, r10, 0).Remote();
auto r11 = actor5.Task(&Counter::Add, r0).Remote();
auto r12 = actor5.Task(&Counter::Add, r11).Remote();
auto r13 = actor5.Task(&Counter::Add, r10).Remote();
auto r14 = actor5.Task(&Counter::Add, r13).Remote();
// auto r15 = Ray::Task(Plus, r0, r11).Remote();
auto r15 = Ray::Task(Plus, 1, r11).Remote();
auto r15 = Ray::Task(Plus, r0, r11).Remote();
auto r16 = Ray::Task(Plus1, r15).Remote();
// int result12 = *(Ray::Get(r12));
int result12 = *(Ray::Get(r12));
int result14 = *(Ray::Get(r14));
// int result11 = *(Ray::Get(r11));
// int result13 = *(Ray::Get(r13));
int result11 = *(Ray::Get(r11));
int result13 = *(Ray::Get(r13));
int result16 = *(Ray::Get(r16));
// int result15 = *(Ray::Get(r15));
int result15 = *(Ray::Get(r15));
// EXPECT_EQ(result11, 28);
// EXPECT_EQ(result12, 56);
// EXPECT_EQ(result13, 83);
// EXPECT_EQ(result14, 166);
EXPECT_EQ(result14, 110);
// EXPECT_EQ(result15, 29);
EXPECT_EQ(result11, 28);
EXPECT_EQ(result12, 56);
EXPECT_EQ(result13, 83);
EXPECT_EQ(result14, 166);
EXPECT_EQ(result15, 29);
EXPECT_EQ(result16, 30);
Ray::Shutdown();
+2 -1
View File
@@ -77,7 +77,8 @@ class DataOrganizer:
job_workers = {}
node_workers = {}
core_worker_stats = {}
for node_id in DataSource.nodes.keys():
# await inside for loop, so we create a copy of keys().
for node_id in list(DataSource.nodes.keys()):
workers = await cls.get_node_workers(node_id)
for worker in workers:
job_id = worker["jobId"]
@@ -4,7 +4,7 @@ import ray.utils
import ray.new_dashboard.utils as dashboard_utils
import ray.new_dashboard.actor_utils as actor_utils
from ray.new_dashboard.utils import rest_response
from ray.new_dashboard.datacenter import DataOrganizer
from ray.new_dashboard.datacenter import DataOrganizer, DataSource
from ray.core.generated import core_worker_pb2
from ray.core.generated import core_worker_pb2_grpc
@@ -29,6 +29,14 @@ class LogicalViewHead(dashboard_utils.DashboardHeadModule):
message="Fetched actor groups.",
actor_groups=actor_groups)
@routes.get("/logical/actors")
@dashboard_utils.aiohttp_cache
async def get_all_actors(self, req) -> aiohttp.web.Response:
    """Handle GET /logical/actors by returning ``DataSource.actors``.

    Args:
        req: The incoming request; it is not inspected.

    Returns:
        The result of ``dashboard_utils.rest_response`` with ``success=True``
        and the actor table passed under the ``actors`` keyword.
    """
    response_fields = dict(
        success=True,
        message="All actors fetched.",
        actors=DataSource.actors)
    return dashboard_utils.rest_response(**response_fields)
@routes.get("/logical/kill_actor")
async def kill_actor(self, req) -> aiohttp.web.Response:
try:
@@ -35,7 +35,7 @@ def test_actor_groups(ray_start_with_dashboard):
assert wait_until_server_available(webui_url)
webui_url = format_web_url(webui_url)
timeout_seconds = 5
timeout_seconds = 10
start_time = time.time()
last_ex = None
while True:
@@ -79,6 +79,63 @@ def test_actor_groups(ray_start_with_dashboard):
raise Exception(f"Timed out while testing, {ex_stack}")
def test_actors(disable_aiohttp_cache, ray_start_with_dashboard):
    """Check that ``/logical/actors`` lists every actor created by this test.

    Two ``Foo`` actors and one actor requesting a GPU are created, then the
    dashboard endpoint is polled once per second until the response passes
    all assertions or ``timeout_seconds`` have elapsed.

    Raises:
        Exception: If no attempt succeeded before the deadline; the message
            carries the traceback of the most recent failed attempt.
    """

    @ray.remote
    class Foo:
        def __init__(self, num):
            self.num = num

        def do_task(self):
            return self.num

    @ray.remote(num_gpus=1)
    class InfeasibleActor:
        pass

    foo_actors = [Foo.remote(4), Foo.remote(5)]
    infeasible_actor = InfeasibleActor.remote()  # noqa
    results = [actor.do_task.remote() for actor in foo_actors]  # noqa
    webui_url = ray_start_with_dashboard["webui_url"]
    assert wait_until_server_available(webui_url)
    webui_url = format_web_url(webui_url)

    timeout_seconds = 5
    start_time = time.time()
    while True:
        time.sleep(1)
        try:
            resp = requests.get(f"{webui_url}/logical/actors")
            resp_json = resp.json()
            resp_data = resp_json["data"]
            actors = resp_data["actors"]
            assert len(actors) == 3
            one_entry = list(actors.values())[0]
            assert "jobId" in one_entry
            assert "taskSpec" in one_entry
            assert "functionDescriptor" in one_entry["taskSpec"]
            assert type(one_entry["taskSpec"]["functionDescriptor"]) is dict
            assert "address" in one_entry
            assert type(one_entry["address"]) is dict
            assert "state" in one_entry
            assert "name" in one_entry
            assert "numRestarts" in one_entry
            assert "pid" in one_entry
            all_pids = [entry["pid"] for entry in actors.values()]
            assert 0 in all_pids  # The infeasible actor
            assert len(all_pids) > 1
            break
        except Exception as ex:
            last_ex = ex
        # Only reached after a failed attempt. The deadline check used to sit
        # in a `finally`, which also ran after a successful `break` and could
        # turn a passing attempt into a timeout with an empty traceback.
        if time.time() > start_time + timeout_seconds:
            ex_stack = "".join(
                traceback.format_exception(
                    type(last_ex), last_ex, last_ex.__traceback__))
            raise Exception(f"Timed out while testing, {ex_stack}")
def test_kill_actor(ray_start_with_dashboard):
@ray.remote
class Actor:
+10 -2
View File
@@ -13,6 +13,7 @@ import ray.new_dashboard.utils as dashboard_utils
import ray._private.services
import ray.utils
from ray.autoscaler._private.util import (DEBUG_AUTOSCALING_STATUS,
DEBUG_AUTOSCALING_STATUS_LEGACY,
DEBUG_AUTOSCALING_ERROR)
from ray.core.generated import reporter_pb2
from ray.core.generated import reporter_pb2_grpc
@@ -113,13 +114,20 @@ class ReportHead(dashboard_utils.DashboardHeadModule):
"""
aioredis_client = self._dashboard_head.aioredis_client
status = await aioredis_client.hget(DEBUG_AUTOSCALING_STATUS, "value")
legacy_status = await aioredis_client.hget(
DEBUG_AUTOSCALING_STATUS_LEGACY, "value")
formatted_status_string = await aioredis_client.hget(
DEBUG_AUTOSCALING_STATUS, "value")
formatted_status = json.loads(formatted_status_string.decode()
) if formatted_status_string else {}
error = await aioredis_client.hget(DEBUG_AUTOSCALING_ERROR, "value")
return dashboard_utils.rest_response(
success=True,
message="Got cluster status.",
autoscaling_status=status.decode() if status else None,
autoscaling_status=legacy_status.decode()
if legacy_status else None,
autoscaling_error=error.decode() if error else None,
cluster_status=formatted_status if formatted_status else None,
)
async def run(self, server):
@@ -246,7 +246,8 @@ class StatsCollector(dashboard_utils.DashboardHeadModule):
@async_loop_forever(
stats_collector_consts.NODE_STATS_UPDATE_INTERVAL_SECONDS)
async def _update_node_stats(self):
for node_id, stub in self._stubs.items():
# Copy self._stubs to avoid `dictionary changed size during iteration`.
for node_id, stub in list(self._stubs.items()):
node_info = DataSource.nodes.get(node_id)
if node_info["state"] != "ALIVE":
continue
@@ -112,20 +112,16 @@ def test_memory_table(disable_aiohttp_cache, ray_start_with_dashboard):
def check_mem_table():
resp = requests.get(f"{webui_url}/memory/memory_table")
resp_data = resp.json()
if not resp_data["result"]:
return False
assert resp_data["result"]
latest_memory_table = resp_data["data"]["memoryTable"]
summary = latest_memory_table["summary"]
try:
# 1 ref per handle and per object the actor has a ref to
assert summary["totalActorHandles"] == len(actors) * 2
# 1 ref for my_obj
assert summary["totalLocalRefCount"] == 1
return True
except AssertionError:
return False
# 1 ref per handle and per object the actor has a ref to
assert summary["totalActorHandles"] == len(actors) * 2
# 1 ref for my_obj
assert summary["totalLocalRefCount"] == 1
wait_for_condition(check_mem_table, 10)
wait_until_succeeded_without_exception(
check_mem_table, (AssertionError, ), timeout_ms=1000)
def test_get_all_node_details(disable_aiohttp_cache, ray_start_with_dashboard):
+7 -2
View File
@@ -19,7 +19,7 @@ from ray import ray_constants
from ray.test_utils import (format_web_url, wait_for_condition,
wait_until_server_available, run_string_as_driver,
wait_until_succeeded_without_exception)
from ray.autoscaler._private.util import (DEBUG_AUTOSCALING_STATUS,
from ray.autoscaler._private.util import (DEBUG_AUTOSCALING_STATUS_LEGACY,
DEBUG_AUTOSCALING_ERROR)
import ray.new_dashboard.consts as dashboard_consts
import ray.new_dashboard.utils as dashboard_utils
@@ -458,11 +458,14 @@ def test_get_cluster_status(ray_start_with_dashboard):
def get_cluster_status():
    """Fetch ``/api/cluster_status`` and assert the expected empty state.

    The autoscaling status and error must both be present and ``None``, and
    ``clusterStatus`` must contain a ``loadMetricsReport`` entry.
    """
    response = requests.get(f"{webui_url}/api/cluster_status")
    response.raise_for_status()
    body = response.json()
    print(body)
    assert body["result"]
    data = body["data"]
    for none_valued_key in ("autoscalingStatus", "autoscalingError"):
        assert none_valued_key in data
        assert data[none_valued_key] is None
    assert "clusterStatus" in data
    assert "loadMetricsReport" in data["clusterStatus"]
wait_until_succeeded_without_exception(get_cluster_status,
(requests.RequestException, ))
@@ -478,7 +481,7 @@ def test_get_cluster_status(ray_start_with_dashboard):
port=int(address[1]),
password=ray_constants.REDIS_DEFAULT_PASSWORD)
client.hset(DEBUG_AUTOSCALING_STATUS, "value", "hello")
client.hset(DEBUG_AUTOSCALING_STATUS_LEGACY, "value", "hello")
client.hset(DEBUG_AUTOSCALING_ERROR, "value", "world")
response = requests.get(f"{webui_url}/api/cluster_status")
@@ -488,6 +491,8 @@ def test_get_cluster_status(ray_start_with_dashboard):
assert response.json()["data"]["autoscalingStatus"] == "hello"
assert "autoscalingError" in response.json()["data"]
assert response.json()["data"]["autoscalingError"] == "world"
assert "clusterStatus" in response.json()["data"]
assert "loadMetricsReport" in response.json()["data"]["clusterStatus"]
def test_immutable_types():
+3 -2
View File
@@ -7,8 +7,9 @@ from ray.new_dashboard.memory_utils import (
NODE_ADDRESS = "127.0.0.1"
IS_DRIVER = True
PID = 1
OBJECT_ID = "7wpsIhgZiBz/////AQAAyAEAAAA="
ACTOR_ID = "fffffffffffffffff66d17ba010000c801000000"
OBJECT_ID = "ZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZmZg=="
ACTOR_ID = "fffffffffffffffffffffffffffffffff66d17ba010000c801000000"
DECODED_ID = decode_object_ref_if_needed(OBJECT_ID)
OBJECT_SIZE = 100
+36
View File
@@ -392,3 +392,39 @@ To get information about the current available resource capacity of your cluster
.. autofunction:: ray.available_resources
:noindex:
Object Spilling
---------------
Ray 1.2.0+ has *beta* support for spilling objects to external storage once the capacity
of the object store is used up. Please file a `GitHub issue <https://github.com/ray-project/ray/issues/>`__
if you encounter any problems with this new feature. Eventually, object spilling will be
enabled by default, but for now you need to enable it manually.
To enable object spilling to the local filesystem (single node clusters only):
.. code-block:: python
ray.init(
_system_config={
"automatic_object_spilling_enabled": True,
"object_spilling_config": json.dumps(
{"type": "filesystem", "params": {"directory_path": "/tmp/spill"}},
)
},
)
To enable object spilling to remote storage (any URI supported by `smart_open <https://pypi.org/project/smart-open/>`__):
.. code-block:: python
ray.init(
_system_config={
"automatic_object_spilling_enabled": True,
"max_io_workers": 4, # More IO workers for remote storage.
"min_spilling_size": 100 * 1024 * 1024, # Spill at least 100MB at a time.
"object_spilling_config": json.dumps(
{"type": "smart_open", "params": {"uri": "s3://bucket/path"}},
)
},
)
+1
View File
@@ -41,6 +41,7 @@ MOCK_MODULES = [
"horovod",
"horovod.ray",
"kubernetes",
"mlflow",
"mxnet",
"mxnet.model",
"psutil",
Binary file not shown.

After

Width:  |  Height:  |  Size: 3.8 KiB

+106
View File
@@ -89,6 +89,112 @@ The Ray debugger supports the
`same commands as PDB
<https://docs.python.org/3/library/pdb.html#debugger-commands>`_.
Stepping between Ray tasks
--------------------------
You can use the debugger to step between Ray tasks. Let's take the
following recursive function as an example:
.. code-block:: python
import ray
ray.init()
@ray.remote
def fact(n):
if n == 1:
return n
else:
n_id = fact.remote(n - 1)
return n * ray.get(n_id)
ray.util.pdb.set_trace()
result_ref = fact.remote(5)
result = ray.get(result_ref)
After running the program by executing the Python file and calling
``ray debug``, you can select the breakpoint by pressing ``0`` and
enter. This will result in the following output:
.. code-block:: python
Enter breakpoint index or press enter to refresh: 0
> /Users/pcmoritz/tmp/stepping.py(14)<module>()
-> result_ref = fact.remote(5)
(Pdb)
You can jump into the call with the ``remote`` command in Ray's debugger.
Inside the function, print the value of `n` with ``p(n)``, resulting in
the following output:
.. code-block:: python
-> result_ref = fact.remote(5)
(Pdb) remote
*** Connection closed by remote host ***
Continuing pdb session in different process...
--Call--
> /Users/pcmoritz/tmp/stepping.py(5)fact()
-> @ray.remote
(Pdb) ll
5 -> @ray.remote
6 def fact(n):
7 if n == 1:
8 return n
9 else:
10 n_id = fact.remote(n - 1)
11 return n * ray.get(n_id)
(Pdb) p(n)
5
(Pdb)
Now step into the next remote call again with
``remote`` and print `n`. You can now either continue recursing into
the function by calling ``remote`` a few more times, or you can jump
to the location where ``ray.get`` is called on the result by using the
``get`` debugger command. Use ``get`` again to jump back to the original
call site and use ``p(result)`` to print the result:
.. code-block:: python
Enter breakpoint index or press enter to refresh: 0
> /Users/pcmoritz/tmp/stepping.py(14)<module>()
-> result_ref = fact.remote(5)
(Pdb) remote
*** Connection closed by remote host ***
Continuing pdb session in different process...
--Call--
> /Users/pcmoritz/tmp/stepping.py(5)fact()
-> @ray.remote
(Pdb) p(n)
5
(Pdb) remote
*** Connection closed by remote host ***
Continuing pdb session in different process...
--Call--
> /Users/pcmoritz/tmp/stepping.py(5)fact()
-> @ray.remote
(Pdb) p(n)
4
(Pdb) get
*** Connection closed by remote host ***
Continuing pdb session in different process...
--Return--
> /Users/pcmoritz/tmp/stepping.py(5)fact()->120
-> @ray.remote
(Pdb) get
*** Connection closed by remote host ***
Continuing pdb session in different process...
--Return--
> /Users/pcmoritz/tmp/stepping.py(14)<module>()->None
-> result_ref = fact.remote(5)
(Pdb) p(result)
120
(Pdb)
Post Mortem Debugging
---------------------
+28
View File
@@ -267,6 +267,26 @@ That's it. Let's take a look at an example:
.. literalinclude:: ../../../python/ray/serve/examples/doc/snippet_model_composition.py
.. _serve-sync-async-handles:
Sync and Async Handles
======================
Ray Serve offers two types of ``ServeHandle``. You can use the ``client.get_handle(..., sync=True|False)``
flag to toggle between them.
- When you set ``sync=True`` (the default), a synchronous handle is returned.
Calling ``handle.remote()`` should return a Ray ObjectRef.
- When you set ``sync=False``, an asyncio-based handle is returned. You need to
call it with ``await handle.remote()`` to return a Ray ObjectRef. To use ``await``,
you have to run ``client.get_handle`` and ``handle.remote`` in a Python asyncio event loop.
The async handle has a performance advantage because it uses asyncio directly, whereas
the sync handle talks to an asyncio event loop in a thread. To learn more about
the reasoning behind these, check out our `architecture documentation <./architecture.html>`_.
Monitoring
==========
@@ -327,3 +347,11 @@ as shown below.
:mod:`client.create_backend <ray.serve.api.Client.create_backend>` by
default.
The dependencies required in the backend may be different than
the dependencies installed in the driver program (the one running Serve API
calls). In this case, you can use an
:mod:`ImportedBackend <ray.serve.backends.ImportedBackend>` to specify a
backend based on a class that is installed in the Python environment that
the workers will run in. Example:
.. literalinclude:: ../../../python/ray/serve/examples/doc/imported_backend.py
+15 -22
View File
@@ -117,7 +117,7 @@ policies <serve-split-traffic>`, finding the next available replica, and
batching requests together.
When the request arrives in the model, you can access the data similarly to how
you would with HTTP request. Here are some examples how ServeRequest mirrors Flask.Request:
you would with HTTP request. Here are some examples how ServeRequest mirrors Starlette.Request:
.. list-table::
:header-rows: 1
@@ -125,25 +125,25 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla
* - HTTP
- ServeHandle
- | Request
| (Flask.Request and ServeRequest)
| (Starlette.Request and ServeRequest)
* - ``requests.get(..., headers={...})``
- ``handle.options(http_headers={...})``
- ``request.headers``
* - ``requests.post(...)``
- ``handle.options(http_method="POST")``
- ``requests.method``
* - ``request.get(..., json={...})``
- ``request.method``
* - ``requests.get(..., json={...})``
- ``handle.remote({...})``
- ``request.json``
* - ``request.get(..., form={...})``
- ``await request.json()``
* - ``requests.get(..., form={...})``
- ``handle.remote({...})``
- ``request.form``
* - ``request.get(..., params={"a":"b"})``
- ``await request.form()``
* - ``requests.get(..., params={"a":"b"})``
- ``handle.remote(a="b")``
- ``request.args``
* - ``request.get(..., data="long string")``
- ``request.query_params``
* - ``requests.get(..., data="long string")``
- ``handle.remote("long string")``
- ``request.data``
- ``await request.body()``
* - ``N/A``
- ``handle.remote(python_object)``
- ``request.data``
@@ -157,9 +157,9 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla
.. code-block:: python
import flask
import starlette.requests
if isinstance(request, flask.Request):
if isinstance(request, starlette.requests.Request):
print("Request coming from web!")
elif isinstance(request, ServeRequest):
print("Request coming from Python!")
@@ -170,10 +170,10 @@ you would with HTTP request. Here are some examples how ServeRequest mirrors Fla
.. code-block:: python
handle.remote(flask_request)
handle.remote(starlette_request)
In this case, Serve will `not` wrap it in ServeRequest. You can directly
process the request as a ``flask.Request``.
process the request as a ``starlette.requests.Request``.
How fast is Ray Serve?
----------------------
@@ -187,13 +187,6 @@ You can checkout our `microbenchmark instruction <https://github.com/ray-project
to benchmark on your hardware.
Does Ray Serve use Flask?
-------------------------
Flask is only used as a web request object for servable to consume the data.
We actually use the fastest Python web server: `Uvicorn <https://www.uvicorn.org/>`_ as our web server,
alongside with the power of Python asyncio.
**Flask is ONLY the request object that we are using, Uvicorn (not flask) provides the webserver.**
Can I use asyncio along with Ray Serve?
---------------------------------------
Yes! You can make your servable methods ``async def`` and Serve will run them
+3
View File
@@ -33,6 +33,9 @@ Since Serve is built on Ray, it also allows you to scale to many machines, in yo
If you want to try out Serve, join our `community slack <https://forms.gle/9TSdDYUgxYs8SA9e8>`_
and discuss in the #serve channel.
.. note::
Starting with Ray version 1.3.0, Ray Serve backends must take in a Starlette Request object instead of a Flask Request object.
See the `migration guide <https://docs.google.com/document/d/1CG4y5WTTc4G_MRQGyjnb_eZ7GK3G9dUX6TNLKLnKRAc/edit?usp=sharing>`_ for details.
Installation
============
+3 -5
View File
@@ -19,10 +19,8 @@ Backends
Backends define the implementation of your business logic or models that will handle requests when queries come in to :ref:`serve-endpoint`.
In order to support seamless scalability backends can have many replicas, which are individual processes running in the Ray cluster to handle requests.
To define a backend, first you must define the "handler" or the business logic you'd like to respond with.
The handler should take as input a `Flask Request object <https://flask.palletsprojects.com/en/1.1.x/api/?highlight=request#flask.Request>`_.
The handler should return any JSON-serializable object as output. For a more customizable response type, the handler may return a
The handler should take as input a `Starlette Request object <https://www.starlette.io/requests/>`_ and return any JSON-serializable object as output. For a more customizable response type, the handler may return a
`Starlette Response object <https://www.starlette.io/responses/>`_.
In the future, Ray Serve will support `Starlette Request objects <https://www.starlette.io/requests/>`_ as input as well.
A backend is defined using :mod:`client.create_backend <ray.serve.api.Client.create_backend>`, and the implementation can be defined as either a function or a class.
Use a function when your response is stateless and a class when you might need to maintain some state (like a model).
@@ -32,7 +30,7 @@ A backend consists of a number of *replicas*, which are individual copies of the
.. code-block:: python
def handle_request(flask_request):
def handle_request(starlette_request):
return "hello world"
class RequestHandler:
@@ -40,7 +38,7 @@ A backend consists of a number of *replicas*, which are individual copies of the
def __init__(self, msg):
self.msg = msg
def __call__(self, flask_request):
def __call__(self, starlette_request):
return self.msg
client.create_backend("simple_backend", handle_request)
+4 -1
View File
@@ -23,7 +23,7 @@ Handle API
:members: remote, options
When calling from Python, the backend implementation will receive ``ServeRequest``
objects instead of Flask requests.
objects instead of Starlette requests.
.. autoclass:: ray.serve.utils.ServeRequest
:members:
@@ -31,3 +31,6 @@ objects instead of Flask requests.
Batching Requests
-----------------
.. autofunction:: ray.serve.accept_batch
Built-in Backends
.. autoclass:: ray.serve.backends.ImportedBackend
+5 -5
View File
@@ -30,13 +30,13 @@ You can use the ``@serve.accept_batch`` decorator to annotate a function or a cl
This annotation is needed because batched backends have different APIs compared
to single request backends. In a batched backend, the inputs are a list of values.
For single query backend, the input type is a single Flask request or
For single query backend, the input type is a single Starlette request or
:mod:`ServeRequest <ray.serve.utils.ServeRequest>`:
.. code-block:: python
def single_request(
request: Union[Flask.Request, ServeRequest],
request: Union[starlette.requests.Request, ServeRequest],
):
pass
@@ -47,7 +47,7 @@ types:
@serve.accept_batch
def batched_request(
request: List[Union[Flask.Request, ServeRequest]],
request: List[Union[starlette.requests.Request, ServeRequest]],
):
pass
@@ -84,8 +84,8 @@ Ray Serve was able to evaluate them in batches.
What if you want to evaluate a whole batch in Python? Ray Serve allows you to send
queries via the Python API. A batch of queries can either come from the web server
or the Python API. Requests coming from the Python API will have the similar API
as Flask.Request. See more on the API :ref:`here<serve-handle-explainer>`.
or the Python API. Requests coming from the Python API will have a similar API
to Starlette Request. See more on the API :ref:`here<serve-handle-explainer>`.
.. literalinclude:: ../../../../python/ray/serve/examples/doc/tutorial_batch.py
:start-after: __doc_define_servable_v1_begin__
+9 -3
View File
@@ -70,6 +70,11 @@ Take a look at any of the below tutorials to get started with Tune.
:figure: /images/wandb_logo.png
:description: :doc:`Track your experiment process with the Weights & Biases tools <tune-wandb>`
.. customgalleryitem::
:tooltip: Use MLFlow with Ray Tune.
:figure: /images/mlflow.png
:description: :doc:`Log and track your hyperparameter sweep with MLFlow Tracking & AutoLogging <tune-mlflow>`
.. raw:: html
@@ -81,12 +86,13 @@ Take a look at any of the below tutorials to get started with Tune.
tune-tutorial.rst
tune-advanced-tutorial.rst
tune-lifecycle.rst
tune-distributed.rst
tune-sklearn.rst
tune-lifecycle.rst
tune-mlflow.rst
tune-pytorch-cifar.rst
tune-pytorch-lightning.rst
tune-serve-integration-mnist.rst
tune-sklearn.rst
tune-xgboost.rst
tune-wandb.rst
@@ -156,4 +162,4 @@ Check out:
.. _tune-faq:
.. include:: _faq.rst
.. include:: _faq.rst
@@ -0,0 +1,47 @@
.. _tune-mlflow:
Using MLFlow with Tune
======================
`MLFlow <https://mlflow.org/>`_ is an open source platform to manage the ML lifecycle, including experimentation,
reproducibility, deployment, and a central model registry. It currently offers four components, one of which is
MLFlow Tracking, which records and queries experiments along with their code, data, config, and results.
.. image:: /images/mlflow.png
:height: 80px
:alt: MLflow
:align: center
:target: https://www.mlflow.org/
Ray Tune currently offers two lightweight integrations for MLFlow Tracking.
One is the :ref:`MLFlowLoggerCallback <tune-mlflow-logger>`, which automatically logs
metrics reported to Tune to the MLFlow Tracking API.
The other one is the :ref:`@mlflow_mixin <tune-mlflow-mixin>` decorator, which can be
used with the function API. It automatically
initializes the MLFlow API with Tune's training information and creates a run for each Tune trial.
Then within your training function, you can just use
MLFlow as you normally would, e.g. using ``mlflow.log_metrics()`` or even ``mlflow.autolog()``
to log your training process.
Please :doc:`see here </tune/examples/mlflow_example>` for a full example on how you can use either the
MLFlowLoggerCallback or the mlflow_mixin.
MLFlow AutoLogging
------------------
You can also check out :doc:`this example </tune/examples/mlflow_ptl_example>` to see how you can leverage MLFlow
autologging, in this case with PyTorch Lightning.
MLFlow Logger API
-----------------
.. _tune-mlflow-logger:
.. autoclass:: ray.tune.integration.mlflow.MLFlowLoggerCallback
:noindex:
MLFlow Mixin API
----------------
.. _tune-mlflow-mixin:
.. autofunction:: ray.tune.integration.mlflow.mlflow_mixin
:noindex:
+9 -8
View File
@@ -70,11 +70,7 @@ An example of creating a custom logger can be found in :doc:`/tune/examples/logg
Trainable Logging
-----------------
By default, Tune only logs the *training result dictionaries* from your Trainable. However, you may want to visualize the model weights, model graph, or use a custom logging library that requires multi-process logging. For example, you may want to do this if:
* you're using `Weights and Biases <https://www.wandb.com/>`_
* you're using `MLFlow <https://github.com/mlflow/mlflow/>`__
* you're trying to log images to Tensorboard.
By default, Tune only logs the *training result dictionaries* from your Trainable. However, you may want to visualize the model weights, model graph, or use a custom logging library that requires multi-process logging. For example, you may want to do this if you're trying to log images to Tensorboard.
You can do this in the trainable, as shown below:
@@ -163,12 +159,17 @@ CSVLogger
.. autoclass:: ray.tune.logger.CSVLoggerCallback
MLFLowLogger
MLFlowLogger
------------
Tune also provides a default logger for `MLFlow <https://mlflow.org>`_. You can install MLFlow via ``pip install mlflow``. An example can be found in :doc:`/tune/examples/mlflow_example`. Note that this currently does not include artifact logging support. For this, you can use the native MLFlow APIs inside your Trainable definition.
Tune also provides a default logger for `MLFlow <https://mlflow.org>`_. You can install MLFlow via ``pip install mlflow``.
You can see the :doc:`tutorial here </tune/tutorials/tune-mlflow>`.
.. autoclass:: ray.tune.logger.MLFLowLogger
WandbLogger
-----------
Tune also provides a default logger for `Weights & Biases <https://www.wandb.com/>`_. You can install Wandb via ``pip install wandb``.
You can see the :doc:`tutorial here </tune/tutorials/tune-wandb>`.
.. _logger-interface:
+11 -6
View File
@@ -192,10 +192,18 @@ For a high-level overview, see this example:
# Sample an integer uniformly between -9 (inclusive) and 15 (exclusive)
"randint": tune.randint(-9, 15),
# Sample an integer uniformly between 1 (inclusive) and 10 (exclusive),
# while sampling in log space
"lograndint": tune.lograndint(1, 10),
# Sample a random integer uniformly between -21 (inclusive) and 12 (inclusive (!))
# rounding to increments of 3 (includes 12)
"qrandint": tune.qrandint(-21, 12, 3),
# Sample an integer uniformly between 1 (inclusive) and 10 (inclusive (!)),
# while sampling in log space and rounding to increments of 2
"qlograndint": tune.qlograndint(1, 10, 2),
# Sample an option uniformly from the specified choices
"choice": tune.choice(["a", "b", "c"]),
@@ -263,10 +271,7 @@ Grid Search API
.. autofunction:: ray.tune.grid_search
Internals
---------
References
----------
BasicVariantGenerator
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: ray.tune.suggest.BasicVariantGenerator
See also :ref:`tune-basicvariant`.
+19
View File
@@ -22,6 +22,10 @@ Summary
- Summary
- Website
- Code Example
* - :ref:`Random search/grid search <tune-basicvariant>`
- Random search/grid search
-
- :doc:`/tune/examples/tune_basic_example`
* - :ref:`AxSearch <tune-ax>`
- Bayesian/Bandit Optimization
- [`Ax <https://ax.dev/>`__]
@@ -123,6 +127,21 @@ identifier.
.. note:: This is currently not implemented for: AxSearch, TuneBOHB, SigOptSearch, and DragonflySearch.
.. _tune-basicvariant:
Random search and grid search (tune.suggest.basic_variant.BasicVariantGenerator)
--------------------------------------------------------------------------------
The default and most basic way to do hyperparameter search is via random and grid search.
Ray Tune does this through the :class:`BasicVariantGenerator <ray.tune.suggest.basic_variant.BasicVariantGenerator>`
class that generates trial variants given a search space definition.
The :class:`BasicVariantGenerator <ray.tune.suggest.basic_variant.BasicVariantGenerator>` is used by
default if no search algorithm is passed to
:func:`tune.run() <ray.tune.run>`.
.. autoclass:: ray.tune.suggest.basic_variant.BasicVariantGenerator
.. _tune-ax:
Ax (tune.suggest.ax.AxSearch)
+2 -1
View File
@@ -13,7 +13,7 @@ If any example is broken, or if you'd like to add an example to this page, feel
General Examples
----------------
- :doc:`/tune/examples/tune_basic_example`: Simple example for doing a basic random and grid search.
- :doc:`/tune/examples/async_hyperband_example`: Example of using a simple tuning function with AsyncHyperBandScheduler.
- :doc:`/tune/examples/hyperband_function_example`: Example of using a Trainable function with HyperBandScheduler. Also uses the AsyncHyperBandScheduler.
- :doc:`/tune/examples/pbt_function`: Example of using the function API with a PopulationBasedTraining scheduler.
@@ -88,6 +88,7 @@ Wandb, MLFlow
- :ref:`Tutorial <tune-wandb>` for using `wandb <https://www.wandb.com/>`__ with Ray Tune
- :doc:`/tune/examples/wandb_example`: Example for using `Weights and Biases <https://www.wandb.com/>`__ with Ray Tune.
- :doc:`/tune/examples/mlflow_example`: Example for using `MLFlow <https://github.com/mlflow/mlflow/>`__ with Ray Tune.
- :doc:`/tune/examples/mlflow_ptl_example`: Example for using `MLFlow <https://github.com/mlflow/mlflow/>`__ and `Pytorch Lightning <https://github.com/PyTorchLightning/pytorch-lightning>`_ with Ray Tune.
Tensorflow/Keras
~~~~~~~~~~~~~~~~
@@ -0,0 +1,6 @@
:orphan:
mlflow_ptl_example
~~~~~~~~~~~~~~~~~~
.. literalinclude:: /../../python/ray/tune/examples/mlflow_ptl.py
@@ -0,0 +1,6 @@
:orphan:
tune_basic_example
~~~~~~~~~~~~~~~~~~
.. literalinclude:: /../../python/ray/tune/examples/tune_basic_example.py
+2
View File
@@ -417,6 +417,8 @@ actors.
to the object ref returned by the put exists. This only applies to the specific
ref returned by put, not refs in general or copies of that ref.
See also: `object spilling <advanced.html#object-spilling>`__.
Remote Classes (Actors)
-----------------------
@@ -7,7 +7,7 @@ import java.util.Random;
public class ActorId extends BaseId implements Serializable {
private static final int UNIQUE_BYTES_LENGTH = 4;
private static final int UNIQUE_BYTES_LENGTH = 12;
public static final int LENGTH = JobId.LENGTH + UNIQUE_BYTES_LENGTH;
@@ -10,7 +10,7 @@ import java.util.Random;
*/
public class ObjectId extends BaseId implements Serializable {
public static final int LENGTH = 20;
public static final int LENGTH = 28;
/**
* Create an ObjectId from a ByteBuffer.
@@ -11,7 +11,7 @@ import java.util.Random;
*/
public class UniqueId extends BaseId implements Serializable {
public static final int LENGTH = 20;
public static final int LENGTH = 28;
public static final UniqueId NIL = genNil();
/**
@@ -1,12 +1,13 @@
package io.ray.api.options;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
/**
* The options class for RayCall or ActorCreation.
*/
public abstract class BaseTaskOptions {
public abstract class BaseTaskOptions implements Serializable {
public final Map<String, Double> resources;
@@ -72,9 +72,8 @@ public interface RayRuntime {
*
* @param objectRefs The object references to free.
* @param localOnly Whether only free objects for local object store or not.
* @param deleteCreatingTasks Whether also delete objects' creating tasks from GCS.
*/
void free(List<ObjectRef<?>> objectRefs, boolean localOnly, boolean deleteCreatingTasks);
void free(List<ObjectRef<?>> objectRefs, boolean localOnly);
/**
* Set the resource for the specific node.
@@ -100,9 +100,9 @@ public abstract class AbstractRayRuntime implements RayRuntimeInternal {
}
@Override
public void free(List<ObjectRef<?>> objectRefs, boolean localOnly, boolean deleteCreatingTasks) {
public void free(List<ObjectRef<?>> objectRefs, boolean localOnly) {
objectStore.delete(objectRefs.stream().map(ref -> ((ObjectRefImpl<?>) ref).getId()).collect(
Collectors.toList()), localOnly, deleteCreatingTasks);
Collectors.toList()), localOnly);
}
@Override
@@ -3,10 +3,8 @@ package io.ray.runtime.gcs;
import com.google.common.base.Preconditions;
import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.api.id.ActorId;
import io.ray.api.id.BaseId;
import io.ray.api.id.JobId;
import io.ray.api.id.PlacementGroupId;
import io.ray.api.id.TaskId;
import io.ray.api.id.UniqueId;
import io.ray.api.placementgroup.PlacementGroup;
import io.ray.api.runtimecontext.NodeInfo;
@@ -14,12 +12,10 @@ import io.ray.runtime.generated.Gcs;
import io.ray.runtime.generated.Gcs.GcsNodeInfo;
import io.ray.runtime.generated.Gcs.TablePrefix;
import io.ray.runtime.placementgroup.PlacementGroupUtils;
import io.ray.runtime.util.IdUtil;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.commons.lang3.ArrayUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -31,25 +27,10 @@ public class GcsClient {
private static Logger LOGGER = LoggerFactory.getLogger(GcsClient.class);
private RedisClient primary;
private List<RedisClient> shards;
private GlobalStateAccessor globalStateAccessor;
public GcsClient(String redisAddress, String redisPassword) {
primary = new RedisClient(redisAddress, redisPassword);
int numShards = 0;
try {
numShards = Integer.valueOf(primary.get("NumRedisShards", null));
Preconditions.checkState(numShards > 0,
String.format("Expected at least one Redis shards, found %d.", numShards));
} catch (NumberFormatException e) {
throw new RuntimeException("Failed to get number of redis shards.", e);
}
List<byte[]> shardAddresses = primary.lrange("RedisShards".getBytes(), 0, -1);
Preconditions.checkState(shardAddresses.size() == numShards);
shards = shardAddresses.stream().map((byte[] address) -> {
return new RedisClient(new String(address), redisPassword);
}).collect(Collectors.toList());
globalStateAccessor = GlobalStateAccessor.getInstance(redisAddress, redisPassword);
}
@@ -163,16 +144,6 @@ public class GcsClient {
return actorTableData.getNumRestarts() != 0;
}
/**
* Query whether the raylet task exists in Gcs.
*/
public boolean rayletTaskExistsInGcs(TaskId taskId) {
byte[] key = ArrayUtils.addAll(TablePrefix.RAYLET_TASK.toString().getBytes(),
taskId.getBytes());
RedisClient client = getShardClient(taskId);
return client.exists(key);
}
public JobId nextJobId() {
int jobCounter = (int) primary.incr("JobCounter".getBytes());
return JobId.fromInt(jobCounter);
@@ -186,10 +157,4 @@ public class GcsClient {
LOGGER.debug("Destroying global state accessor.");
GlobalStateAccessor.destroyInstance();
}
private RedisClient getShardClient(BaseId key) {
return shards.get((int) Long.remainderUnsigned(IdUtil.murmurHashCode(key),
shards.size()));
}
}
@@ -21,7 +21,7 @@ public class LocalModeObjectStore extends ObjectStore {
private static final Logger LOGGER = LoggerFactory.getLogger(LocalModeObjectStore.class);
private static final int GET_CHECK_INTERVAL_MS = 100;
private static final int GET_CHECK_INTERVAL_MS = 1;
private final Map<ObjectId, NativeRayObject> pool = new ConcurrentHashMap<>();
private final List<Consumer<ObjectId>> objectPutCallbacks = new ArrayList<>();
@@ -93,7 +93,7 @@ public class LocalModeObjectStore extends ObjectStore {
}
@Override
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
public void delete(List<ObjectId> objectIds, boolean localOnly) {
for (ObjectId objectId : objectIds) {
pool.remove(objectId);
}
@@ -50,8 +50,8 @@ public class NativeObjectStore extends ObjectStore {
}
@Override
public void delete(List<ObjectId> objectIds, boolean localOnly, boolean deleteCreatingTasks) {
nativeDelete(toBinaryList(objectIds), localOnly, deleteCreatingTasks);
public void delete(List<ObjectId> objectIds, boolean localOnly) {
nativeDelete(toBinaryList(objectIds), localOnly);
}
@Override
@@ -116,8 +116,7 @@ public class NativeObjectStore extends ObjectStore {
private static native List<Boolean> nativeWait(List<byte[]> objectIds, int numObjects,
long timeoutMs);
private static native void nativeDelete(List<byte[]> objectIds, boolean localOnly,
boolean deleteCreatingTasks);
private static native void nativeDelete(List<byte[]> objectIds, boolean localOnly);
private static native void nativeAddLocalReference(byte[] workerId, byte[] objectId);
@@ -167,10 +167,8 @@ public abstract class ObjectStore {
*
* @param objectIds IDs of the objects to delete.
* @param localOnly Whether only delete the objects in local node, or all nodes in the cluster.
* @param deleteCreatingTasks Whether also delete the tasks that created these objects.
*/
public abstract void delete(List<ObjectId> objectIds, boolean localOnly,
boolean deleteCreatingTasks);
public abstract void delete(List<ObjectId> objectIds, boolean localOnly);
/**
* Increase the local reference count for this object ID.
@@ -75,14 +75,18 @@ public class RunManager {
// address info of the local node.
String script = String.format("import ray;"
+ " print(ray._private.services.get_address_info_from_redis("
+ "'%s', '%s', redis_password='%s', no_warning=True))",
+ "'%s', '%s', redis_password='%s'))",
rayConfig.getRedisAddress(), rayConfig.nodeIp, rayConfig.redisPassword);
List<String> command = Arrays.asList("python", "-c", script);
String output = null;
try {
output = runCommand(command);
JsonObject addressInfo = new JsonParser().parse(output).getAsJsonObject();
// NOTE(kfstorm): We only parse the last line here in case there are some warning
// messages appear at the beginning.
String[] lines = output.split(System.lineSeparator());
String lastLine = lines[lines.length - 1];
JsonObject addressInfo = new JsonParser().parse(lastLine).getAsJsonObject();
rayConfig.rayletSocketName = addressInfo.get("raylet_socket_name").getAsString();
rayConfig.objectStoreSocketName = addressInfo.get("object_store_address").getAsString();
rayConfig.nodeManagerPort = addressInfo.get("node_manager_port").getAsInt();
@@ -1,7 +1,6 @@
package io.ray.runtime.util;
import io.ray.api.id.ActorId;
import io.ray.api.id.BaseId;
import io.ray.api.id.ObjectId;
import io.ray.api.id.TaskId;
@@ -11,74 +10,6 @@ import io.ray.api.id.TaskId;
*/
public class IdUtil {
/**
* Compute the murmur hash code of this ID.
*/
public static long murmurHashCode(BaseId id) {
return murmurHash64A(id.getBytes(), id.size(), 0);
}
/**
* This method is the same as `Hash()` method of `ID` class in ray/src/ray/common/id.h
*/
private static long murmurHash64A(byte[] data, int length, int seed) {
final long m = 0xc6a4a7935bd1e995L;
final int r = 47;
long h = (seed & 0xFFFFFFFFL) ^ (length * m);
int length8 = length / 8;
for (int i = 0; i < length8; i++) {
final int i8 = i * 8;
long k = ((long) data[i8] & 0xff)
+ (((long) data[i8 + 1] & 0xff) << 8)
+ (((long) data[i8 + 2] & 0xff) << 16)
+ (((long) data[i8 + 3] & 0xff) << 24)
+ (((long) data[i8 + 4] & 0xff) << 32)
+ (((long) data[i8 + 5] & 0xff) << 40)
+ (((long) data[i8 + 6] & 0xff) << 48)
+ (((long) data[i8 + 7] & 0xff) << 56);
k *= m;
k ^= k >>> r;
k *= m;
h ^= k;
h *= m;
}
final int remaining = length % 8;
if (remaining >= 7) {
h ^= (long) (data[(length & ~7) + 6] & 0xff) << 48;
}
if (remaining >= 6) {
h ^= (long) (data[(length & ~7) + 5] & 0xff) << 40;
}
if (remaining >= 5) {
h ^= (long) (data[(length & ~7) + 4] & 0xff) << 32;
}
if (remaining >= 4) {
h ^= (long) (data[(length & ~7) + 3] & 0xff) << 24;
}
if (remaining >= 3) {
h ^= (long) (data[(length & ~7) + 2] & 0xff) << 16;
}
if (remaining >= 2) {
h ^= (long) (data[(length & ~7) + 1] & 0xff) << 8;
}
if (remaining >= 1) {
h ^= (long) (data[length & ~7] & 0xff);
h *= m;
}
h ^= h >>> r;
h *= m;
h ^= h >>> r;
return h;
}
/**
* Compute the actor ID of the task which created this object.
* @return The actor ID of the task which created this object.
@@ -1,7 +1,6 @@
package io.ray.runtime;
import io.ray.api.id.UniqueId;
import io.ray.runtime.util.IdUtil;
import java.nio.ByteBuffer;
import java.util.Arrays;
import javax.xml.bind.DatatypeConverter;
@@ -13,12 +12,12 @@ public class UniqueIdTest {
@Test
public void testConstructUniqueId() {
// Test `fromHexString()`
UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF00");
Assert.assertEquals("00000000123456789abcdef123456789abcdef00", id1.toString());
UniqueId id1 = UniqueId.fromHexString("00000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00");
Assert.assertEquals("00000000123456789abcdef123456789abcdef0123456789abcdef00", id1.toString());
Assert.assertFalse(id1.isNil());
try {
UniqueId id2 = UniqueId.fromHexString("000000123456789ABCDEF123456789ABCDEF00");
UniqueId id2 = UniqueId.fromHexString("000000123456789ABCDEF123456789ABCDEF0123456789ABCDEF00");
// This shouldn't be happened.
Assert.assertTrue(false);
} catch (IllegalArgumentException e) {
@@ -34,23 +33,16 @@ public class UniqueIdTest {
}
// Test `fromByteBuffer()`
byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF01234567");
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 20);
byte[] bytes = DatatypeConverter.parseHexBinary("0123456789ABCDEF0123456789ABCDEF012345670123456789ABCDEF");
ByteBuffer byteBuffer = ByteBuffer.wrap(bytes, 0, 28);
UniqueId id4 = UniqueId.fromByteBuffer(byteBuffer);
Assert.assertTrue(Arrays.equals(bytes, id4.getBytes()));
Assert.assertEquals("0123456789abcdef0123456789abcdef01234567", id4.toString());
Assert.assertEquals("0123456789abcdef0123456789abcdef012345670123456789abcdef", id4.toString());
// Test `genNil()`
UniqueId id6 = UniqueId.NIL;
Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString());
Assert.assertEquals("FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF".toLowerCase(), id6.toString());
Assert.assertTrue(id6.isNil());
}
@Test
void testMurmurHash() {
UniqueId id = UniqueId.fromHexString("3131313131313131313132323232323232323232");
long remainder = Long.remainderUnsigned(IdUtil.murmurHashCode(id), 1000000000);
Assert.assertEquals(remainder, 787616861);
}
}
@@ -128,7 +128,7 @@ public class ActorTest extends BaseTest {
ObjectRef value = counter.task(Counter::getValue).remote();
Assert.assertEquals(100, value.get());
// Delete the object from the object store.
Ray.internal().free(ImmutableList.of(value), false, false);
Ray.internal().free(ImmutableList.of(value), false);
// Wait for delete RPC to propagate
TimeUnit.SECONDS.sleep(1);
// Free deletes from in-memory store.
@@ -138,7 +138,7 @@ public class ActorTest extends BaseTest {
ObjectRef<TestUtils.LargeObject> largeValue = counter.task(Counter::createLargeObject).remote();
Assert.assertTrue(largeValue.get() instanceof TestUtils.LargeObject);
// Delete the object from the object store.
Ray.internal().free(ImmutableList.of(largeValue), false, false);
Ray.internal().free(ImmutableList.of(largeValue), false);
// Wait for delete RPC to propagate
TimeUnit.SECONDS.sleep(1);
// Free deletes big objects from plasma store.
@@ -15,7 +15,8 @@ public class DynamicResourceTest extends BaseTest {
return "hi";
}
@Test(groups = {"cluster"})
// Dynamic resources not supported yet.
@Test(groups = {"cluster"}, enabled = false)
public void testSetResource() {
// Call a task in advance to warm up the cluster to avoid being too slow to start workers.
TestUtils.warmUpCluster();
@@ -3,9 +3,7 @@ package io.ray.test;
import com.google.common.collect.ImmutableList;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.id.TaskId;
import io.ray.runtime.object.ObjectRefImpl;
import java.util.Arrays;
import org.testng.Assert;
import org.testng.annotations.Test;
@@ -20,7 +18,7 @@ public class PlasmaFreeTest extends BaseTest {
ObjectRef<String> helloId = Ray.task(PlasmaFreeTest::hello).remote();
String helloString = helloId.get();
Assert.assertEquals("hello", helloString);
Ray.internal().free(ImmutableList.of(helloId), true, false);
Ray.internal().free(ImmutableList.of(helloId), true);
final boolean result = TestUtils.waitForCondition(() ->
!TestUtils.getRuntime().getObjectStore()
@@ -32,19 +30,4 @@ public class PlasmaFreeTest extends BaseTest {
Assert.assertFalse(result);
}
}
@Test(groups = {"cluster"})
public void testDeleteCreatingTasks() {
ObjectRef<String> helloId = Ray.task(PlasmaFreeTest::hello).remote();
Assert.assertEquals("hello", helloId.get());
Ray.internal().free(ImmutableList.of(helloId), true, true);
TaskId taskId = TaskId.fromBytes(
Arrays.copyOf(((ObjectRefImpl<String>) helloId).getId().getBytes(), TaskId.LENGTH));
final boolean result = TestUtils.waitForCondition(
() -> !TestUtils.getRuntime().getGcsClient()
.rayletTaskExistsInGcs(taskId), 50);
Assert.assertTrue(result);
}
}
+49
View File
@@ -0,0 +1,49 @@
import os
from contextlib import contextmanager
from functools import wraps
# Client mode is on when the RAY_CLIENT_MODE environment variable is exactly
# "1" at import time; _explicitly_enable_client_mode() can also turn it on.
client_mode_enabled = os.environ.get("RAY_CLIENT_MODE", "0") == "1"
# Second gate checked by client_mode_hook's wrapper: delegation to the client
# happens only while both flags are true.
_client_hook_enabled = True
def _enable_client_hook(val: bool) -> None:
    """Set the module-level ``_client_hook_enabled`` flag to ``val``.

    Despite the name, passing ``False`` leaves the hook disabled;
    ``disable_client_hook`` calls this to restore the state it saved on entry.
    """
    global _client_hook_enabled
    _client_hook_enabled = val
def _disable_client_hook():
    """Turn the client hook off and hand back its previous setting.

    Returns:
        The value ``_client_hook_enabled`` held immediately before this call,
        so the caller can later restore it via ``_enable_client_hook``.
    """
    global _client_hook_enabled
    # Read the old value and clear the flag in a single assignment.
    previous, _client_hook_enabled = _client_hook_enabled, False
    return previous
def _explicitly_enable_client_mode() -> None:
    """Force ``client_mode_enabled`` to ``True``, whatever it was at import."""
    global client_mode_enabled
    client_mode_enabled = True
@contextmanager
def disable_client_hook():
    """Context manager that suspends the client hook for the enclosed block.

    The hook's prior state is captured on entry and reinstated on exit, even
    when the body raises, so nested uses leave the flag as they found it.
    """
    previous_state = _disable_client_hook()
    try:
        yield
    finally:
        _enable_client_hook(previous_state)
def client_mode_hook(func):
    """Wrap a ray module function so it can be rerouted to the ray client.

    While both ``client_mode_enabled`` and ``_client_hook_enabled`` are true,
    calls are forwarded to the attribute of the client ``ray`` object that has
    the same name as ``func``; otherwise ``func`` itself runs.
    """
    # Imported at decoration time, inside the function, as in the original.
    from ray.experimental.client import ray

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Both flags are re-read on every call, so toggling them affects
        # functions that were decorated earlier.
        if not (client_mode_enabled and _client_hook_enabled):
            return func(*args, **kwargs)
        client_impl = getattr(ray, func.__name__)
        return client_impl(*args, **kwargs)

    return wrapper
@@ -0,0 +1,83 @@
import inspect
import logging
import sys
from ray.experimental.client.ray_client_helpers import ray_start_client_server
from ray._private.ray_microbenchmark_helpers import timeit
from ray._private.ray_microbenchmark_helpers import ray_setup_and_teardown
def benchmark_get_calls(ray):
    """Time repeated ``ray.get`` calls on one object stored via ``ray.put(0)``."""
    # The object is stored once, outside the timed callable.
    small_ref = ray.put(0)

    def fetch_once():
        ray.get(small_ref)

    timeit("client: get calls", fetch_once)
def benchmark_put_calls(ray):
    """Time repeated ``ray.put(0)`` calls issued through ``ray``."""

    def store_once():
        ray.put(0)

    timeit("client: put calls", store_once)
def benchmark_remote_put_calls(ray):
    """Time batches of remote tasks that each perform 100 ``ray.put(0)`` calls."""

    @ray.remote
    def do_put_small():
        for _ in range(100):
            ray.put(0)

    def launch_batch():
        pending = [do_put_small.remote() for _ in range(10)]
        ray.get(pending)

    # One batch is 10 tasks x 100 puts, hence the multiplier of 1000.
    timeit("client: remote put calls", launch_batch, 1000)
def benchmark_simple_actor(ray):
    """Time sync, async and concurrent calls against a single zero-CPU actor."""

    @ray.remote(num_cpus=0)
    class Actor:
        def small_value(self):
            return b"ok"

        def small_value_arg(self, x):
            return b"ok"

        def small_value_batch(self, n):
            # NOTE(review): nothing in this function invokes this method.
            ray.get([self.small_value.remote() for _ in range(n)])

    handle = Actor.remote()

    def one_call_blocking():
        ray.get(handle.small_value.remote())

    timeit("client: 1:1 actor calls sync", one_call_blocking)

    def thousand_calls_pipelined():
        pending = [handle.small_value.remote() for _ in range(1000)]
        ray.get(pending)

    timeit("client: 1:1 actor calls async", thousand_calls_pipelined, 1000)

    # Rebind the same local: the closure below resolves ``handle`` at call
    # time and must target the max_concurrency=16 actor.
    handle = Actor.options(max_concurrency=16).remote()

    def thousand_calls_concurrent():
        pending = [handle.small_value.remote() for _ in range(1000)]
        ray.get(pending)

    timeit("client: 1:1 actor calls concurrent", thousand_calls_concurrent,
           1000)
def main():
    """Run every module-level ``benchmark_*`` member against a client server.

    Ray is set up once for the whole run; each benchmark then gets its own
    ``ray_start_client_server()`` context and receives the object it yields.
    """
    init_kwargs = {
        "logging_level": logging.WARNING,
        "_system_config": {"put_small_object_in_memory_store": True},
    }
    with ray_setup_and_teardown(**init_kwargs):
        this_module = sys.modules[__name__]
        for attr_name, benchmark in inspect.getmembers(this_module):
            if attr_name.startswith("benchmark_"):
                with ray_start_client_server() as ray:
                    benchmark(ray)
if __name__ == "__main__":
main()
@@ -0,0 +1,39 @@
import time
import os
import ray
import numpy as np
from contextlib import contextmanager
# Only run tests matching this filter pattern.
filter_pattern = os.environ.get("TESTS_TO_RUN", "")
def timeit(name, fn, multiplier=1):
    """Benchmark ``fn`` and print its throughput as mean +- std.

    Benchmarks whose ``name`` does not contain ``filter_pattern`` are skipped
    and ``fn`` is never called. Otherwise ``fn`` is called repeatedly for about
    one second of warmup, then for four timed rounds of about two seconds each.

    Intervals are measured with ``time.perf_counter()`` rather than
    ``time.time()``: the wall clock can be stepped while a round is running,
    which would end the round early or late and skew the reported rate.

    Args:
        name: Label printed with the result and matched against the filter.
        fn: Zero-argument callable to benchmark; its return value is ignored.
        multiplier: Factor the measured calls-per-second rate is multiplied
            by before it is reported.

    Returns:
        None. The result is written to stdout.
    """
    if filter_pattern not in name:
        return
    # warmup
    start = time.perf_counter()
    while time.perf_counter() - start < 1:
        fn()
    # real run
    stats = []
    for _ in range(4):
        start = time.perf_counter()
        count = 0
        while time.perf_counter() - start < 2:
            fn()
            count += 1
        # Taken after the loop exits so the divisor covers the whole round.
        end = time.perf_counter()
        stats.append(multiplier * count / (end - start))
    print(name, "per second", round(np.mean(stats), 2), "+-",
          round(np.std(stats), 2))
@contextmanager
def ray_setup_and_teardown(**init_args):
    """Context manager bracketing a block with ``ray.init`` / ``ray.shutdown``.

    Args:
        **init_args: Keyword arguments passed straight through to ``ray.init``.

    ``ray.shutdown()`` runs when the block exits, whether or not it raised.
    Nothing useful is yielded (the ``as`` target is ``None``).
    """
    ray.init(**init_args)
    try:
        yield
    finally:
        ray.shutdown()
+12 -13
View File
@@ -279,8 +279,7 @@ def get_address_info_from_redis_helper(redis_address,
def get_address_info_from_redis(redis_address,
node_ip_address,
num_retries=5,
redis_password=None,
no_warning=False):
redis_password=None):
counter = 0
while True:
try:
@@ -291,11 +290,10 @@ def get_address_info_from_redis(redis_address,
raise
# Some of the information may not be in Redis yet, so wait a little
# bit.
if not no_warning:
logger.warning(
"Some processes that the driver needs to connect to have "
"not registered with Redis, so retrying. Have you run "
"'ray start' on this node?")
logger.warning(
"Some processes that the driver needs to connect to have "
"not registered with Redis, so retrying. Have you run "
"'ray start' on this node?")
time.sleep(1)
counter += 1
@@ -1618,12 +1616,13 @@ def determine_plasma_store_config(object_store_memory,
logger.warning(
"WARNING: The object store is using {} instead of "
"/dev/shm because /dev/shm has only {} bytes available. "
"This may slow down performance! You may be able to free "
"up space by deleting files in /dev/shm or terminating "
"any running plasma_store_server processes. If you are "
"inside a Docker container, you may need to pass an "
"argument with the flag '--shm-size' to 'docker run'.".
format(ray.utils.get_user_temp_dir(), shm_avail))
"This will harm performance! You may be able to free up "
"space by deleting files in /dev/shm. If you are inside a "
"Docker container, you can increase /dev/shm size by "
"passing '--shm-size=Xgb' to 'docker run' (or add it to "
"the run_options list in a Ray cluster config). Make sure "
"to set this to more than 2gb.".format(
ray.utils.get_user_temp_dir(), shm_avail))
else:
plasma_directory = ray.utils.get_user_temp_dir()
+19 -19
View File
@@ -107,6 +107,10 @@ from ray.exceptions import (
TaskCancelledError
)
from ray.utils import decode
from ray._private.client_mode_hook import (
_enable_client_hook,
_disable_client_hook,
)
import msgpack
cimport cpython
@@ -558,6 +562,7 @@ cdef CRayStatus task_execution_handler(
with gil:
try:
client_was_enabled = _disable_client_hook()
try:
# The call to execute_task should never raise an exception. If
# it does, that indicates that there was an internal error.
@@ -582,6 +587,8 @@ cdef CRayStatus task_execution_handler(
else:
logger.exception("SystemExit was raised from the worker")
return CRayStatus.UnexpectedSystemExit()
finally:
_enable_client_hook(client_was_enabled)
return CRayStatus.OK()
@@ -638,9 +645,11 @@ cdef c_vector[c_string] spill_objects_handler(
return return_urls
cdef void restore_spilled_objects_handler(
cdef int64_t restore_spilled_objects_handler(
const c_vector[CObjectID]& object_ids_to_restore,
const c_vector[c_string]& object_urls) nogil:
cdef:
int64_t bytes_restored = 0
with gil:
urls = []
size = object_urls.size()
@@ -651,7 +660,8 @@ cdef void restore_spilled_objects_handler(
with ray.worker._changeproctitle(
ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER,
ray_constants.WORKER_PROCESS_TYPE_RESTORE_WORKER_IDLE):
external_storage.restore_spilled_objects(object_refs, urls)
bytes_restored = external_storage.restore_spilled_objects(
object_refs, urls)
except Exception:
exception_str = (
"An unexpected internal error occurred while the IO worker "
@@ -662,6 +672,7 @@ cdef void restore_spilled_objects_handler(
"restore_spilled_objects_error",
traceback.format_exc() + exception_str,
job_id=None)
return bytes_restored
cdef void delete_spilled_objects_handler(
@@ -873,7 +884,8 @@ cdef class CoreWorker:
return self.plasma_event_handler
def get_objects(self, object_refs, TaskID current_task_id,
int64_t timeout_ms=-1, plasma_objects_only=False):
int64_t timeout_ms=-1,
plasma_objects_only=False):
cdef:
c_vector[shared_ptr[CRayObject]] results
CTaskID c_task_id = current_task_id.native()
@@ -1004,7 +1016,7 @@ cdef class CoreWorker:
return c_object_id.Binary()
def wait(self, object_refs, int num_returns, int64_t timeout_ms,
TaskID current_task_id):
TaskID current_task_id, c_bool fetch_local):
cdef:
c_vector[CObjectID] wait_ids
c_vector[c_bool] results
@@ -1013,7 +1025,7 @@ cdef class CoreWorker:
wait_ids = ObjectRefsToVector(object_refs)
with nogil:
check_status(CCoreWorkerProcess.GetCoreWorker().Wait(
wait_ids, num_returns, timeout_ms, &results))
wait_ids, num_returns, timeout_ms, &results, fetch_local))
assert len(results) == len(object_refs)
@@ -1026,14 +1038,13 @@ cdef class CoreWorker:
return ready, not_ready
def free_objects(self, object_refs, c_bool local_only,
c_bool delete_creating_tasks):
def free_objects(self, object_refs, c_bool local_only):
cdef:
c_vector[CObjectID] free_ids = ObjectRefsToVector(object_refs)
with nogil:
check_status(CCoreWorkerProcess.GetCoreWorker().Delete(
free_ids, local_only, delete_creating_tasks))
free_ids, local_only))
def global_gc(self):
with nogil:
@@ -1573,17 +1584,6 @@ cdef class CoreWorker:
resource_name.encode("ascii"), capacity,
CNodeID.FromBinary(client_id.binary()))
def force_spill_objects(self, object_refs):
cdef c_vector[CObjectID] object_ids
object_ids = ObjectRefsToVector(object_refs)
assert not RayConfig.instance().automatic_object_deletion_enabled(), (
"Automatic object deletion is not supported for"
"force_spill_objects yet. Please set"
"automatic_object_deletion_enabled: False in Ray's system config.")
with nogil:
check_status(CCoreWorkerProcess.GetCoreWorker()
.SpillObjects(object_ids))
cdef void async_set_result(shared_ptr[CRayObject] obj,
CObjectID object_ref,
void *future) with gil:
+103 -74
View File
@@ -1,4 +1,4 @@
from collections import defaultdict, namedtuple
from collections import defaultdict, namedtuple, Counter
from typing import Any, Optional, Dict, List
from urllib3.exceptions import MaxRetryError
import copy
@@ -16,8 +16,10 @@ from ray.experimental.internal_kv import _internal_kv_put, \
from ray.autoscaler.tags import (
TAG_RAY_LAUNCH_CONFIG, TAG_RAY_RUNTIME_CONFIG,
TAG_RAY_FILE_MOUNTS_CONTENTS, TAG_RAY_NODE_STATUS, TAG_RAY_NODE_KIND,
TAG_RAY_USER_NODE_TYPE, STATUS_UP_TO_DATE, NODE_KIND_WORKER,
NODE_KIND_UNMANAGED, NODE_KIND_HEAD)
TAG_RAY_USER_NODE_TYPE, STATUS_UNINITIALIZED, STATUS_WAITING_FOR_SSH,
STATUS_SYNCING_FILES, STATUS_SETTING_UP, STATUS_UP_TO_DATE,
NODE_KIND_WORKER, NODE_KIND_UNMANAGED, NODE_KIND_HEAD)
from ray.autoscaler._private.legacy_info_string import legacy_log_info_string
from ray.autoscaler._private.providers import _get_node_provider
from ray.autoscaler._private.updater import NodeUpdaterThread
from ray.autoscaler._private.node_launcher import NodeLauncher
@@ -25,8 +27,8 @@ from ray.autoscaler._private.resource_demand_scheduler import \
get_bin_pack_residual, ResourceDemandScheduler, NodeType, NodeID, NodeIP, \
ResourceDict
from ray.autoscaler._private.util import ConcurrentCounter, validate_config, \
with_head_node_ip, hash_launch_conf, hash_runtime_conf, add_prefix, \
DEBUG_AUTOSCALING_STATUS, DEBUG_AUTOSCALING_ERROR
with_head_node_ip, hash_launch_conf, hash_runtime_conf, \
DEBUG_AUTOSCALING_ERROR, format_info_string
from ray.autoscaler._private.constants import \
AUTOSCALER_MAX_NUM_FAILURES, AUTOSCALER_MAX_LAUNCH_BATCH, \
AUTOSCALER_MAX_CONCURRENT_LAUNCHES, AUTOSCALER_UPDATE_INTERVAL_S, \
@@ -41,20 +43,23 @@ UpdateInstructions = namedtuple(
"UpdateInstructions",
["node_id", "init_commands", "start_ray_commands", "docker_config"])
AutoscalerSummary = namedtuple(
"AutoscalerSummary",
["active_nodes", "pending_nodes", "pending_launches", "failed_nodes"])
class StandardAutoscaler:
"""The autoscaling control loop for a Ray cluster.
There are two ways to start an autoscaling cluster: manually by running
`ray start --head --autoscaling-config=/path/to/config.yaml` on a
instance that has permission to launch other instances, or you can also use
`ray up /path/to/config.yaml` from your laptop, which will
configure the right AWS/Cloud roles automatically.
StandardAutoscaler's `update` method is periodically called by `monitor.py`
to add and remove nodes as necessary. Currently, load-based autoscaling is
not implemented, so all this class does is try to maintain a constant
cluster size.
`ray start --head --autoscaling-config=/path/to/config.yaml` on an instance
that has permission to launch other instances, or you can also use `ray up
/path/to/config.yaml` from your laptop, which will configure the right
AWS/Cloud roles automatically. See the documentation for a full definition
of autoscaling behavior:
https://docs.ray.io/en/master/cluster/autoscaling.html
StandardAutoscaler's `update` method is periodically called in
`monitor.py`'s monitoring loop.
StandardAutoscaler is also used to bootstrap clusters (by adding workers
until the cluster size that can handle the resource demand is met).
@@ -120,9 +125,6 @@ class StandardAutoscaler:
for local_path in self.config["file_mounts"].values():
assert os.path.exists(local_path)
# List of resource bundles the user is requesting of the cluster.
self.resource_demand_vector = []
logger.info("StandardAutoscaler: {}".format(self.config))
def update(self):
@@ -149,7 +151,6 @@ class StandardAutoscaler:
def _update(self):
now = time.time()
# Throttle autoscaling updates to this interval to avoid exceeding
# rate limits on API calls.
if now - self.last_update_time < self.update_interval_s:
@@ -162,7 +163,6 @@ class StandardAutoscaler:
self.provider.internal_ip(node_id)
for node_id in self.all_workers()
])
self.log_info_string(nodes)
# Terminate any idle or out of date nodes
last_used = self.load_metrics.last_used_time_by_ip
@@ -176,7 +176,7 @@ class StandardAutoscaler:
sorted_node_ids = self._sort_based_on_last_used(nodes, last_used)
# Don't terminate nodes needed by request_resources()
nodes_allowed_to_terminate: Dict[NodeID, bool] = {}
if self.resource_demand_vector:
if self.load_metrics.get_resource_requests():
nodes_allowed_to_terminate = self._get_nodes_allowed_to_terminate(
sorted_node_ids)
@@ -202,7 +202,6 @@ class StandardAutoscaler:
if nodes_to_terminate:
self.provider.terminate_nodes(nodes_to_terminate)
nodes = self.workers()
self.log_info_string(nodes)
# Terminate nodes if there are too many
nodes_to_terminate = []
@@ -217,8 +216,6 @@ class StandardAutoscaler:
self.provider.terminate_nodes(nodes_to_terminate)
nodes = self.workers()
self.log_info_string(nodes)
to_launch = self.resource_demand_scheduler.get_nodes_to_launch(
self.provider.non_terminated_nodes(tag_filters={}),
self.pending_launches.breakdown(),
@@ -226,7 +223,7 @@ class StandardAutoscaler:
self.load_metrics.get_resource_utilization(),
self.load_metrics.get_pending_placement_groups(),
self.load_metrics.get_static_node_resources_by_ip(),
ensure_min_cluster_size=self.resource_demand_vector)
ensure_min_cluster_size=self.load_metrics.get_resource_requests())
for node_type, count in to_launch.items():
self.launch_new_node(count, node_type=node_type)
@@ -256,7 +253,6 @@ class StandardAutoscaler:
self.provider.terminate_nodes(nodes_to_terminate)
nodes = self.workers()
self.log_info_string(nodes)
# Update nodes with out-of-date files.
# TODO(edoakes): Spawning these threads directly seems to cause
@@ -282,6 +278,9 @@ class StandardAutoscaler:
for node_id in nodes:
self.recover_if_needed(node_id, now)
logger.info(self.info_string())
legacy_log_info_string(self, nodes)
def _sort_based_on_last_used(self, nodes: List[NodeID],
last_used: Dict[str, float]) -> List[NodeID]:
"""Sort the nodes based on the last time they were used.
@@ -333,7 +332,7 @@ class StandardAutoscaler:
NodeIP,
ResourceDict] = \
self.load_metrics.get_static_node_resources_by_ip()
head_node_resources = static_nodes[head_ip]
head_node_resources = static_nodes.get(head_ip, {})
else:
head_node_resources = {}
@@ -362,7 +361,7 @@ class StandardAutoscaler:
used_resource_requests: List[ResourceDict]
_, used_resource_requests = \
get_bin_pack_residual(max_node_resources,
self.resource_demand_vector)
self.load_metrics.get_resource_requests())
# Remove the first entry (the head node).
max_node_resources.pop(0)
# Remove the first entry (the head node).
@@ -482,11 +481,13 @@ class StandardAutoscaler:
# for legacy yamls.
self.resource_demand_scheduler.reset_config(
self.provider, self.available_node_types,
self.config["max_workers"], upscaling_speed)
self.config["max_workers"], self.config["head_node_type"],
upscaling_speed)
else:
self.resource_demand_scheduler = ResourceDemandScheduler(
self.provider, self.available_node_types,
self.config["max_workers"], upscaling_speed)
self.config["max_workers"], self.config["head_node_type"],
upscaling_speed)
except Exception as e:
if errors_fatal:
@@ -532,15 +533,17 @@ class StandardAutoscaler:
if not self.can_update(node_id):
return
key = self.provider.internal_ip(node_id)
if key not in self.load_metrics.last_heartbeat_time_by_ip:
self.load_metrics.last_heartbeat_time_by_ip[key] = now
last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip[key]
delta = now - last_heartbeat_time
if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S:
return
if key in self.load_metrics.last_heartbeat_time_by_ip:
last_heartbeat_time = self.load_metrics.last_heartbeat_time_by_ip[
key]
delta = now - last_heartbeat_time
if delta < AUTOSCALER_HEARTBEAT_TIMEOUT_S:
return
logger.warning("StandardAutoscaler: "
"{}: No heartbeat in {}s, "
"restarting Ray to recover...".format(node_id, delta))
"{}: No recent heartbeat, "
"restarting Ray to recover...".format(node_id))
updater = NodeUpdaterThread(
node_id=node_id,
provider_config=self.config["provider"],
@@ -677,43 +680,6 @@ class StandardAutoscaler:
return self.provider.non_terminated_nodes(
tag_filters={TAG_RAY_NODE_KIND: NODE_KIND_UNMANAGED})
def log_info_string(self, nodes):
tmp = "Cluster status: "
tmp += self.info_string(nodes)
tmp += "\n"
tmp += self.load_metrics.info_string()
tmp += "\n"
tmp += self.resource_demand_scheduler.debug_string(
nodes, self.pending_launches.breakdown(),
self.load_metrics.get_resource_utilization())
if _internal_kv_initialized():
_internal_kv_put(DEBUG_AUTOSCALING_STATUS, tmp, overwrite=True)
if self.prefix_cluster_info:
tmp = add_prefix(tmp, self.config["cluster_name"])
logger.debug(tmp)
def info_string(self, nodes):
suffix = ""
if self.updaters:
suffix += " ({} updating)".format(len(self.updaters))
if self.num_failed_updates:
suffix += " ({} failed to update)".format(
len(self.num_failed_updates))
return "{} nodes{}".format(len(nodes), suffix)
def request_resources(self, resources: List[dict]):
"""Called by monitor to request resources.
Args:
resources: A list of resource bundles.
"""
if resources:
logger.info(
"StandardAutoscaler: resource_requests={}".format(resources))
assert isinstance(resources, list), resources
self.resource_demand_vector = resources
def kill_workers(self):
logger.error("StandardAutoscaler: kill_workers triggered")
nodes = self.workers()
@@ -721,3 +687,66 @@ class StandardAutoscaler:
self.provider.terminate_nodes(nodes)
logger.error("StandardAutoscaler: terminated {} node(s)".format(
len(nodes)))
def summary(self):
    """Classify every managed node as active, pending, or failed.

    A node is active when `self.load_metrics.is_active` accepts its
    internal IP. A non-active node is pending when its status tag is
    uninitialized, waiting for ssh, syncing files, or setting up. Any
    other non-active node is reported as failed. Nodes tagged as
    unmanaged are left out entirely.

    Returns:
        AutoscalerSummary: `active_nodes` is a Counter keyed by node type,
            `pending_nodes` and `failed_nodes` are lists of
            `(ip, node_type)` pairs, and `pending_launches` maps node
            type to its non-zero pending launch count.
    """
    # Status tags that mean a node is still being brought up.
    pending_states = [
        STATUS_UNINITIALIZED, STATUS_WAITING_FOR_SSH,
        STATUS_SYNCING_FILES, STATUS_SETTING_UP
    ]

    active_nodes = Counter()
    pending_nodes = []
    failed_nodes = []

    for node_id in self.provider.non_terminated_nodes(tag_filters={}):
        ip = self.provider.internal_ip(node_id)
        node_tags = self.provider.node_tags(node_id)
        if node_tags[TAG_RAY_NODE_KIND] == NODE_KIND_UNMANAGED:
            continue
        node_type = node_tags[TAG_RAY_USER_NODE_TYPE]

        # TODO (Alex): If a node's raylet has died, it shouldn't be marked
        # as active.
        if self.load_metrics.is_active(ip):
            active_nodes[node_type] += 1
            continue

        if node_tags[TAG_RAY_NODE_STATUS] in pending_states:
            pending_nodes.append((ip, node_type))
        else:
            # TODO (Alex): Failed nodes are now immediately killed, so
            # this list will almost always be empty. We should ideally
            # keep a cache of recently failed nodes and their startup
            # logs.
            failed_nodes.append((ip, node_type))

    # The concurrent counter leaves some 0 counts in, so only keep the
    # node types that actually have launches in flight.
    pending_launches = {
        node_type: count
        for node_type, count in self.pending_launches.breakdown().items()
        if count
    }

    return AutoscalerSummary(
        active_nodes=active_nodes,
        pending_nodes=pending_nodes,
        pending_launches=pending_launches,
        failed_nodes=failed_nodes)
def info_string(self):
    """Return the formatted status report, preceded by a newline.

    The report is built by `format_info_string` from the load metrics
    summary and this autoscaler's own summary.
    """
    report = format_info_string(self.load_metrics.summary(), self.summary())
    return "\n" + report
@@ -35,6 +35,8 @@ logger = logging.getLogger(__name__)
HASH_MAX_LENGTH = 10
KUBECTL_RSYNC = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "kubernetes/kubectl-rsync.sh")
MAX_HOME_RETRIES = 3
HOME_RETRY_DELAY_S = 5
_config = {"use_login_shells": True, "silent_rsync": True}
@@ -248,16 +250,31 @@ class KubernetesCommandRunner(CommandRunnerInterface):
@property
def _home(self):
if self._home_cached is not None:
return self._home_cached
for _ in range(MAX_HOME_RETRIES - 1):
try:
self._home_cached = self._try_to_get_home()
return self._home_cached
except Exception:
# TODO (Dmitri): Identify the exception we're trying to avoid.
logger.info("Error reading container's home directory. "
f"Retrying in {HOME_RETRY_DELAY_S} seconds.")
time.sleep(HOME_RETRY_DELAY_S)
# Last try
self._home_cached = self._try_to_get_home()
return self._home_cached
def _try_to_get_home(self):
# TODO (Dmitri): Think about how to use the node's HOME variable
# without making an extra kubectl exec call.
if self._home_cached is None:
cmd = self.kubectl + [
"exec", "-it", self.node_id, "--", "printenv", "HOME"
]
joined_cmd = " ".join(cmd)
raw_out = self.process_runner.check_output(joined_cmd, shell=True)
self._home_cached = raw_out.decode().strip("\n\r")
return self._home_cached
cmd = self.kubectl + [
"exec", "-it", self.node_id, "--", "printenv", "HOME"
]
joined_cmd = " ".join(cmd)
raw_out = self.process_runner.check_output(joined_cmd, shell=True)
home = raw_out.decode().strip("\n\r")
return home
class SSHOptions:
+16 -3
View File
@@ -43,6 +43,10 @@ from ray.worker import global_worker # type: ignore
from ray.util.debug import log_once
import ray.autoscaler._private.subprocess_output_util as cmd_output_util
from ray.autoscaler._private.load_metrics import LoadMetricsSummary
from ray.autoscaler._private.autoscaler import AutoscalerSummary
from ray.autoscaler._private.util import format_info_string, \
format_info_string_no_node_types
logger = logging.getLogger(__name__)
@@ -94,6 +98,14 @@ def debug_status() -> str:
status = "No cluster status."
else:
status = status.decode("utf-8")
as_dict = json.loads(status)
lm_summary = LoadMetricsSummary(**as_dict["load_metrics_report"])
if "autoscaler_report" in as_dict:
autoscaler_summary = AutoscalerSummary(
**as_dict["autoscaler_report"])
status = format_info_string(lm_summary, autoscaler_summary)
else:
status = format_info_string_no_node_types(lm_summary)
if error:
status += "\n"
status += error.decode("utf-8")
@@ -280,9 +292,10 @@ def _bootstrap_config(config: Dict[str, Any],
f"Failed to autodetect node resources: {str(exc)}. "
"You can see full stack trace with higher verbosity.")
# NOTE: if `resources` field is missing, validate_config for non-AWS will
# fail (the schema error will ask the user to manually fill the resources)
# as we currently support autofilling resources for AWS instances only.
# NOTE: if `resources` field is missing, validate_config for providers
# other than AWS and Kubernetes will fail (the schema error will ask the
# user to manually fill the resources) as we currently support autofilling
# resources for AWS and Kubernetes only.
validate_config(config)
resolved_config = provider_cls.bootstrap_config(config)
@@ -60,6 +60,13 @@ def bootstrap_kubernetes(config):
def fillout_resources_kubernetes(config):
"""Fills CPU and GPU resources by reading pod spec of each available node
type.
For each node type and each of CPU/GPU, looks at container's resources
and limits, takes min of the two. The result is rounded up, as Ray does
not currently support fractional CPU.
"""
if "available_node_types" not in config:
return config["available_node_types"]
node_types = copy.deepcopy(config["available_node_types"])
@@ -96,20 +103,47 @@ def get_resource(container_resources, resource_name):
limit = _get_resource(
container_resources, resource_name, field_name="limits")
resource = min(request, limit)
# float("inf") value means the resource wasn't detected in either
# requests or limits
return 0 if resource == float("inf") else int(resource)
def _get_resource(container_resources, resource_name, field_name):
if (field_name in container_resources
and resource_name in container_resources[field_name]):
return _parse_resource(container_resources[field_name][resource_name])
else:
"""Returns the resource quantity.
The amount of resource is rounded up to nearest integer.
Returns float("inf") if the resource is not present.
Args:
container_resources (dict): Container's resource field.
resource_name (str): One of 'cpu' or 'gpu'.
field_name (str): One of 'requests' or 'limits'.
Returns:
Union[int, float]: Detected resource quantity.
"""
if field_name not in container_resources:
# No limit/resource field.
return float("inf")
resources = container_resources[field_name]
# Look for keys containing the resource_name. For example,
# the key 'nvidia.com/gpu' contains the key 'gpu'.
matching_keys = [key for key in resources if resource_name in key.lower()]
if len(matching_keys) == 0:
return float("inf")
if len(matching_keys) > 1:
# Should have only one match -- mostly relevant for gpu.
raise ValueError(f"Multiple {resource_name} types not supported.")
# E.g. 'nvidia.com/gpu' or 'cpu'.
resource_key = matching_keys.pop()
resource_quantity = resources[resource_key]
return _parse_resource(resource_quantity)
def _parse_resource(resource):
resource_str = str(resource)
if resource_str[-1] == "m":
# For example, '500m' rounds up to 1.
return math.ceil(int(resource_str[:-1]) / 1000)
else:
return int(resource_str)
@@ -0,0 +1,35 @@
import logging
from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS_LEGACY
from ray.experimental.internal_kv import _internal_kv_put, \
_internal_kv_initialized
"""This file provides legacy support for the old info string in order to
ensure the dashboard's `api/cluster_status` does not break backwards
compatibility.
"""
logger = logging.getLogger(__name__)
def legacy_log_info_string(autoscaler, nodes):
    """Build the old-style cluster status text, store it, and log it.

    The text is the "Cluster status: " line, the load metrics info string,
    and the resource demand scheduler debug string, joined by newlines. It
    is written to the internal KV store under
    DEBUG_AUTOSCALING_STATUS_LEGACY when the store is initialized, and is
    always logged at debug level.

    Args:
        autoscaler: The autoscaler whose state is reported.
        nodes: The nodes to report on.
    """
    sections = [
        "Cluster status: " + info_string(autoscaler, nodes),
        autoscaler.load_metrics.info_string(),
        autoscaler.resource_demand_scheduler.debug_string(
            nodes, autoscaler.pending_launches.breakdown(),
            autoscaler.load_metrics.get_resource_utilization()),
    ]
    legacy_status = "\n".join(sections)
    if _internal_kv_initialized():
        _internal_kv_put(
            DEBUG_AUTOSCALING_STATUS_LEGACY, legacy_status, overwrite=True)
    logger.debug(legacy_status)
def info_string(autoscaler, nodes):
    """Return a one-line node count with updating/failed annotations.

    Args:
        autoscaler: Object exposing `updaters` and `num_failed_updates`.
        nodes: The nodes being counted.

    Returns:
        str: "<n> nodes", followed by " (<k> updating)" and/or
            " (<k> failed to update)" when those collections are non-empty.
    """
    annotations = []
    if autoscaler.updaters:
        annotations.append(" ({} updating)".format(len(autoscaler.updaters)))
    if autoscaler.num_failed_updates:
        annotations.append(" ({} failed to update)".format(
            len(autoscaler.num_failed_updates)))
    return "{} nodes{}".format(len(nodes), "".join(annotations))
+100 -16
View File
@@ -1,16 +1,26 @@
from collections import namedtuple
from functools import reduce
import logging
import time
from typing import Dict, List
import numpy as np
import ray._private.services as services
from ray.autoscaler._private.constants import MEMORY_RESOURCE_UNIT_BYTES
from ray.autoscaler._private.constants import MEMORY_RESOURCE_UNIT_BYTES,\
AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE
from ray.autoscaler._private.util import add_resources, freq_of_dicts
from ray.gcs_utils import PlacementGroupTableData
from ray.autoscaler._private.resource_demand_scheduler import \
NodeIP, ResourceDict
from ray.core.generated.common_pb2 import PlacementStrategy
logger = logging.getLogger(__name__)
LoadMetricsSummary = namedtuple("LoadMetricsSummary", [
"head_ip", "usage", "resource_demand", "pg_demand", "request_demand",
"node_types"
])
class LoadMetrics:
"""Container for cluster load metrics.
@@ -31,6 +41,7 @@ class LoadMetrics:
self.waiting_bundles = []
self.infeasible_bundles = []
self.pending_placement_groups = []
self.resource_requests = []
def update(self,
ip: str,
@@ -72,34 +83,37 @@ class LoadMetrics:
def mark_active(self, ip):
assert ip is not None, "IP should be known at this time"
logger.info("Node {} is newly setup, treating as active".format(ip))
logger.debug("Node {} is newly setup, treating as active".format(ip))
self.last_heartbeat_time_by_ip[ip] = time.time()
def is_active(self, ip):
return ip in self.last_heartbeat_time_by_ip
def prune_active_ips(self, active_ips):
active_ips = set(active_ips)
active_ips.add(self.local_ip)
def prune(mapping):
def prune(mapping, should_log):
unwanted = set(mapping) - active_ips
for unwanted_key in unwanted:
# TODO (Alex): Change this back to info after #12138.
logger.debug("LoadMetrics: "
"Removed mapping: {} - {}".format(
unwanted_key, mapping[unwanted_key]))
if should_log:
logger.info("LoadMetrics: "
"Removed mapping: {} - {}".format(
unwanted_key, mapping[unwanted_key]))
del mapping[unwanted_key]
if unwanted:
if unwanted and should_log:
# TODO (Alex): Change this back to info after #12138.
logger.debug(
logger.info(
"LoadMetrics: "
"Removed {} stale ip mappings: {} not in {}".format(
len(unwanted), unwanted, active_ips))
assert not (unwanted & set(mapping))
prune(self.last_used_time_by_ip)
prune(self.static_resources_by_ip)
prune(self.dynamic_resources_by_ip)
prune(self.resource_load_by_ip)
prune(self.last_heartbeat_time_by_ip)
prune(self.last_used_time_by_ip, should_log=True)
prune(self.static_resources_by_ip, should_log=False)
prune(self.dynamic_resources_by_ip, should_log=False)
prune(self.resource_load_by_ip, should_log=False)
prune(self.last_heartbeat_time_by_ip, should_log=False)
def get_node_resources(self):
"""Return a list of node resources (static resource sizes).
@@ -155,12 +169,82 @@ class LoadMetrics:
return resources_used, resources_total
def get_resource_demand_vector(self):
return self.waiting_bundles + self.infeasible_bundles
def get_resource_demand_vector(self, clip=True):
    """Return the waiting bundles followed by the infeasible bundles.

    Args:
        clip (bool): When true, each of the two lists is truncated to
            AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE entries before they
            are concatenated.

    Returns:
        list: The combined (optionally clipped) demand vector.
    """
    if not clip:
        return self.waiting_bundles + self.infeasible_bundles

    # Bound the total number of bundles to
    # 2xMAX_RESOURCE_DEMAND_VECTOR_SIZE. This guarantees the resource
    # demand scheduler bin packing algorithm takes a reasonable amount
    # of time to run.
    limit = AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE
    return self.waiting_bundles[:limit] + self.infeasible_bundles[:limit]
def get_resource_requests(self):
return self.resource_requests
def get_pending_placement_groups(self):
return self.pending_placement_groups
def summary(self):
    """Condense the current load metrics into a LoadMetricsSummary.

    Returns:
        LoadMetricsSummary: `usage` maps each resource in the static
            totals to a `(used, total)` pair; `resource_demand`,
            `request_demand`, and `node_types` are frequency lists
            produced by `freq_of_dicts`; `pg_demand` is a frequency list
            of pending placement groups, each rendered as a dict with
            "bundles" and "strategy" keys; `head_ip` is `self.local_ip`.
    """

    def _sum_resources(resources_by_ip):
        # An empty mapping has nothing to reduce over.
        if not resources_by_ip:
            return {}
        return reduce(add_resources, resources_by_ip.values())

    available_resources = _sum_resources(self.dynamic_resources_by_ip)
    total_resources = _sum_resources(self.static_resources_by_ip)

    usage_dict = {
        key: (total - available_resources[key], total)
        for key, total in total_resources.items()
    }

    summarized_demand_vector = freq_of_dicts(
        self.get_resource_demand_vector(clip=False))
    summarized_resource_requests = freq_of_dicts(
        self.get_resource_requests())

    def _pg_to_key(pg):
        # Placement groups are not hashable, so count them by a tuple of
        # frozen bundle items plus the strategy.
        frozen_bundles = tuple(
            frozenset(bundle.unit_resources.items())
            for bundle in pg.bundles)
        return (frozen_bundles, pg.strategy)

    def _pg_from_key(pg_key):
        # We marshal this as a dictionary so that we can easily json.dumps
        # it later.
        # TODO (Alex): Would there be a benefit to properly
        # marshalling this (into a protobuf)?
        frozen_bundles, strategy = pg_key[0], pg_key[1]
        bundles = [dict(frozen) for frozen in frozen_bundles]
        return {
            "bundles": freq_of_dicts(bundles),
            "strategy": PlacementStrategy.Name(strategy)
        }

    summarized_placement_groups = freq_of_dicts(
        self.get_pending_placement_groups(),
        serializer=_pg_to_key,
        deserializer=_pg_from_key)
    nodes_summary = freq_of_dicts(self.static_resources_by_ip.values())

    return LoadMetricsSummary(
        head_ip=self.local_ip,
        usage=usage_dict,
        resource_demand=summarized_demand_vector,
        pg_demand=summarized_placement_groups,
        request_demand=summarized_resource_requests,
        node_types=nodes_summary)
def set_resource_requests(self, requested_resources):
if requested_resources is not None:
assert isinstance(requested_resources, list), requested_resources
self.resource_requests = [
request for request in requested_resources if len(request) > 0
]
def info_string(self):
return " - " + "\n - ".join(
["{}: {}".format(k, v) for k, v in sorted(self._info().items())])
@@ -47,16 +47,19 @@ class ResourceDemandScheduler:
provider: NodeProvider,
node_types: Dict[NodeType, NodeTypeConfigDict],
max_workers: int,
head_node_type: NodeType,
upscaling_speed: float = 1) -> None:
self.provider = provider
self.node_types = copy.deepcopy(node_types)
self.max_workers = max_workers
self.head_node_type = head_node_type
self.upscaling_speed = upscaling_speed
def reset_config(self,
provider: NodeProvider,
node_types: Dict[NodeType, NodeTypeConfigDict],
max_workers: int,
head_node_type: NodeType,
upscaling_speed: float = 1) -> None:
"""Updates the class state variables.
@@ -89,6 +92,7 @@ class ResourceDemandScheduler:
self.provider = provider
self.node_types = copy.deepcopy(final_node_types)
self.max_workers = max_workers
self.head_node_type = head_node_type
self.upscaling_speed = upscaling_speed
def is_legacy_yaml(self,
@@ -145,18 +149,18 @@ class ResourceDemandScheduler:
node_resources, node_type_counts = self.calculate_node_resources(
nodes, launching_nodes, unused_resources_by_ip)
logger.info("Cluster resources: {}".format(node_resources))
logger.info("Node counts: {}".format(node_type_counts))
logger.debug("Cluster resources: {}".format(node_resources))
logger.debug("Node counts: {}".format(node_type_counts))
# Step 2: add nodes to add to satisfy min_workers for each type
(node_resources,
node_type_counts,
adjusted_min_workers) = \
_add_min_workers_nodes(
node_resources, node_type_counts, self.node_types,
self.max_workers, ensure_min_cluster_size)
self.max_workers, self.head_node_type, ensure_min_cluster_size)
# Step 3: add nodes for strict spread groups
logger.info(f"Placement group demands: {pending_placement_groups}")
logger.debug(f"Placement group demands: {pending_placement_groups}")
placement_group_demand_vector, strict_spreads = \
placement_groups_to_resource_demands(pending_placement_groups)
resource_demands.extend(placement_group_demand_vector)
@@ -183,12 +187,13 @@ class ResourceDemandScheduler:
# groups
unfulfilled, _ = get_bin_pack_residual(node_resources,
resource_demands)
logger.info("Resource demands: {}".format(resource_demands))
logger.info("Unfulfilled demands: {}".format(unfulfilled))
logger.debug("Resource demands: {}".format(resource_demands))
logger.debug("Unfulfilled demands: {}".format(unfulfilled))
# Add 1 to account for the head node.
max_to_add = self.max_workers + 1 - sum(node_type_counts.values())
nodes_to_add_based_on_demand = get_nodes_for(
self.node_types, node_type_counts, max_to_add, unfulfilled)
self.node_types, node_type_counts, self.head_node_type, max_to_add,
unfulfilled)
# Merge nodes to add based on demand and nodes to add based on
# min_workers constraint. We add them because nodes to add based on
# demand was calculated after the min_workers constraint was respected.
@@ -206,7 +211,7 @@ class ResourceDemandScheduler:
total_nodes_to_add, unused_resources_by_ip.keys(), nodes,
launching_nodes, adjusted_min_workers)
logger.info("Node requests: {}".format(total_nodes_to_add))
logger.debug("Node requests: {}".format(total_nodes_to_add))
return total_nodes_to_add
def _legacy_worker_node_to_launch(
@@ -443,6 +448,7 @@ class ResourceDemandScheduler:
to_launch = get_nodes_for(
self.node_types,
node_type_counts,
self.head_node_type,
max_to_add,
unfulfilled,
strict_spread=True)
@@ -490,7 +496,7 @@ def _add_min_workers_nodes(
node_resources: List[ResourceDict],
node_type_counts: Dict[NodeType, int],
node_types: Dict[NodeType, NodeTypeConfigDict], max_workers: int,
ensure_min_cluster_size: List[ResourceDict]
head_node_type: NodeType, ensure_min_cluster_size: List[ResourceDict]
) -> (List[ResourceDict], Dict[NodeType, int], Dict[NodeType, int]):
"""Updates resource demands to respect the min_workers and
request_resources() constraints.
@@ -515,6 +521,9 @@ def _add_min_workers_nodes(
existing = node_type_counts.get(node_type, 0)
target = min(
config.get("min_workers", 0), config.get("max_workers", 0))
if node_type == head_node_type:
# Add 1 to account for head node.
target = target + 1
if existing < target:
total_nodes_to_add_dict[node_type] = target - existing
node_type_counts[node_type] = target
@@ -537,7 +546,7 @@ def _add_min_workers_nodes(
max_node_resources, ensure_min_cluster_size)
# Get the nodes to meet the unfulfilled.
nodes_to_add_request_resources = get_nodes_for(
node_types, node_type_counts, max_to_add,
node_types, node_type_counts, head_node_type, max_to_add,
resource_requests_unfulfilled)
# Update the resources, counts and total nodes to add.
for node_type in nodes_to_add_request_resources:
@@ -558,6 +567,7 @@ def _add_min_workers_nodes(
def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
existing_nodes: Dict[NodeType, int],
head_node_type: NodeType,
max_to_add: int,
resources: List[ResourceDict],
strict_spread: bool = False) -> Dict[NodeType, int]:
@@ -581,9 +591,13 @@ def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
while resources and sum(nodes_to_add.values()) < max_to_add:
utilization_scores = []
for node_type in node_types:
max_workers_of_node_type = node_types[node_type].get(
"max_workers", 0)
if head_node_type == node_type:
# Add 1 to account for head node.
max_workers_of_node_type = max_workers_of_node_type + 1
if (existing_nodes.get(node_type, 0) + nodes_to_add.get(
node_type, 0) >= node_types[node_type].get(
"max_workers", 0)):
node_type, 0) >= max_workers_of_node_type):
continue
node_resources = node_types[node_type]["resources"]
if strict_spread:
@@ -601,8 +615,14 @@ def get_nodes_for(node_types: Dict[NodeType, NodeTypeConfigDict],
# starts up because placement groups are scheduled via custom
# resources. This will behave properly with the current utilization
# score heuristic, but it's a little dangerous and misleading.
logger.info(
"No feasible node type to add for {}".format(resources))
logger.warning(
f"The autoscaler could not find a node type to satisfy the"
f"request: {resources}. If this request is related to "
f"placement groups the resource request will resolve itself, "
f"otherwise please specify a node type with the necessary "
f"resource "
f"https://docs.ray.io/en/master/cluster/autoscaling.html#multiple-node-type-autoscaling." # noqa: E501
)
break
utilization_scores = sorted(utilization_scores, reverse=True)
+174 -1
View File
@@ -1,13 +1,15 @@
import collections
from datetime import datetime
import logging
import hashlib
import json
import jsonschema
import os
import threading
from typing import Any, Dict
from typing import Any, Dict, List
import ray
import ray.ray_constants
import ray._private.services as services
from ray.autoscaler._private.providers import _get_default_config
from ray.autoscaler._private.docker import validate_docker_config
@@ -20,6 +22,7 @@ RAY_SCHEMA_PATH = os.path.join(
# Internal kv keys for storing debug status.
DEBUG_AUTOSCALING_ERROR = "__autoscaling_error"
DEBUG_AUTOSCALING_STATUS = "__autoscaling_status"
DEBUG_AUTOSCALING_STATUS_LEGACY = "__autoscaling_status_legacy"
logger = logging.getLogger(__name__)
@@ -246,6 +249,47 @@ def hash_runtime_conf(file_mounts,
return (_hash_cache[conf_str], file_mounts_contents_hash)
def add_resources(dict1: Dict[str, float],
                  dict2: Dict[str, float]) -> Dict[str, float]:
    """Sum two resource dictionaries key by key.

    Keys present in only one input keep their value unchanged.

    Returns:
        dict: A new dictionary (inputs remain unmodified).
    """
    totals = dict1.copy()
    for name in dict2:
        totals[name] = dict2[name] + totals.get(name, 0)
    return totals
def freq_of_dicts(dicts: List[Dict],
                  serializer=lambda d: frozenset(d.items()),
                  deserializer=dict):
    """Count how often each entry occurs in a list of unhashable items.

    Dictionaries (and other mutable structures) cannot be used as counter
    keys directly, so every entry is first converted to a hashable form,
    counted, and then converted back.

    Args:
        dicts (List[D]): The entries to be counted.
        serializer (D -> S): Converts an entry into a hashable value of
            type S. The default turns a dictionary into a frozenset of its
            key/value pairs.
        deserializer (S -> U): Converts a value of type S back into the
            form that is returned. For dictionaries U := D.

    Returns:
        List[Tuple[U, int]]: One tuple per distinct entry of `dicts`,
            pairing the entry with its frequency count.
    """
    counts = collections.Counter(serializer(entry) for entry in dicts)
    return [(deserializer(key), count) for key, count in counts.items()]
def add_prefix(info_string, prefix):
"""Prefixes each line of info_string, except the first, by prefix."""
lines = info_string.split("\n")
@@ -255,3 +299,132 @@ def add_prefix(info_string, prefix):
prefixed_lines.append(prefixed_line)
prefixed_info_string = "\n".join(prefixed_lines)
return prefixed_info_string
def format_pg(pg):
    """Render a placement-group entry as a single line.

    Args:
        pg: Mapping with a "strategy" value and a "bundles" iterable of
            `(bundle, count)` pairs.

    Returns:
        str: The bundles as "<bundle> * <count>" joined by ", ", followed by
            the strategy in parentheses.
    """
    placement = pg["strategy"]
    rendered = ", ".join(
        f"{shape} * {amount}" for shape, amount in pg["bundles"])
    return f"{rendered} ({placement})"
def get_usage_report(lm_summary):
    """Build the "used/total resource" lines of the status report.

    Args:
        lm_summary: Object whose `usage` attribute maps a resource name to a
            `(used, total)` pair.

    Returns:
        str: One line per resource, joined with newlines. "memory" and
            "object_store_memory" are scaled by
            MEMORY_RESOURCE_UNIT_BYTES / 2**30 and labelled GiB.
    """
    memory_like = ["memory", "object_store_memory"]
    rows = []
    for resource, (used, total) in lm_summary.usage.items():
        if resource in memory_like:
            scale = ray.ray_constants.MEMORY_RESOURCE_UNIT_BYTES / 2**30
            used *= scale
            total *= scale
            rows.append(f" {used:.2f}/{total:.3f} GiB {resource}")
        else:
            rows.append(f" {used}/{total} {resource}")
    return "\n".join(rows)
def get_demand_report(lm_summary):
    """Build the "Demands" section of the status report.

    Lines are emitted in this order: pending tasks/actors
    (`resource_demand`), pending placement groups (`pg_demand`, rendered via
    `format_pg`), then `request_resources()` demands (`request_demand`).

    Returns:
        str: The demand lines joined with newlines, or a placeholder line
            when there are none.
    """
    rows = [
        f" {bundle}: {count}+ pending tasks/actors"
        for bundle, count in lm_summary.resource_demand
    ]
    for pg, count in lm_summary.pg_demand:
        rows.append(f" {format_pg(pg)}: {count}+ pending placement groups")
    rows.extend(f" {bundle}: {count}+ from request_resources()"
                for bundle, count in lm_summary.request_demand)

    if not rows:
        return " (no resource demands)"
    return "\n".join(rows)
def format_info_string(lm_summary, autoscaler_summary, time=None):
    """Format the full autoscaler status report.

    Args:
        lm_summary: Load-metrics summary passed on to `get_usage_report` and
            `get_demand_report`.
        autoscaler_summary: Object exposing `active_nodes` and
            `pending_launches` (mappings of node type to count) and
            `pending_nodes` and `failed_nodes` (iterables of
            `(ip, node_type)` pairs).
        time: Timestamp shown in the header; defaults to `datetime.now()`.

    Returns:
        str: The multi-line report (node status followed by resources).
    """
    if time is None:
        time = datetime.now()
    header = "=" * 8 + f" Autoscaler status: {time} " + "=" * 8
    separator = "-" * len(header)

    available_node_report_lines = []
    for node_type, count in autoscaler_summary.active_nodes.items():
        line = f" {count} {node_type}"
        available_node_report_lines.append(line)
    available_node_report = "\n".join(available_node_report_lines)

    pending_lines = []
    for node_type, count in autoscaler_summary.pending_launches.items():
        line = f" {node_type}, {count} launching"
        pending_lines.append(line)
    for ip, node_type in autoscaler_summary.pending_nodes:
        line = f" {ip}: {node_type}, setting up"
        pending_lines.append(line)
    if pending_lines:
        pending_report = "\n".join(pending_lines)
    else:
        pending_report = " (no pending nodes)"

    failure_lines = []
    for ip, node_type in autoscaler_summary.failed_nodes:
        line = f" {ip}: {node_type}"
        # Previously the line was built but never collected, so the report
        # always claimed "(no failures)" even when nodes had failed.
        failure_lines.append(line)
    failure_report = "Recent failures:\n"
    if failure_lines:
        failure_report += "\n".join(failure_lines)
    else:
        failure_report += " (no failures)"

    usage_report = get_usage_report(lm_summary)
    demand_report = get_demand_report(lm_summary)

    formatted_output = f"""{header}
Node status
{separator}
Healthy:
{available_node_report}
Pending:
{pending_report}
{failure_report}
Resources
{separator}
Usage:
{usage_report}
Demands:
{demand_report}"""
    return formatted_output
def format_info_string_no_node_types(lm_summary, time=None):
    """Format the cluster status report when only load metrics are known.

    Args:
        lm_summary: Object whose `node_types` attribute yields
            `(node_type, count)` pairs; it is also passed on to
            `get_usage_report` and `get_demand_report`.
        time: Timestamp shown in the header; defaults to `datetime.now()`.

    Returns:
        str: The multi-line report (node status followed by resources).
    """
    if time is None:
        time = datetime.now()
    header = "=" * 8 + f" Cluster status: {time} " + "=" * 8
    separator = "-" * len(header)

    node_report = "\n".join(
        f" {count} node(s) with resources: {node_type}"
        for node_type, count in lm_summary.node_types)
    usage_report = get_usage_report(lm_summary)
    demand_report = get_demand_report(lm_summary)

    # Each entry becomes one line of the report, in display order.
    sections = [
        header,
        "Node status",
        separator,
        node_report,
        "Resources",
        separator,
        "Usage:",
        usage_report,
        "Demands:",
        demand_report,
    ]
    return "\n".join(sections)
@@ -250,9 +250,11 @@ worker_nodes:
# Files or directories to copy to the head and worker nodes. The format is a
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
file_mounts: {
# "/path1/on/remote/machine": "/path1/on/local/machine",
# "/path2/on/remote/machine": "/path2/on/local/machine",
# "~/path1/on/remote/machine": "/path1/on/local/machine",
# "~/path2/on/remote/machine": "/path2/on/local/machine",
}
# Note that the container images in this example have a non-root user.
# To avoid permissions issues, we recommend mounting into a subdirectory of home (~).
# Files or directories to copy from the head node to the worker nodes. The format is a
# list of paths. The same path on the head node will be copied to the worker node.
@@ -250,9 +250,11 @@ worker_nodes:
# Files or directories to copy to the head and worker nodes. The format is a
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
file_mounts: {
# "/path1/on/remote/machine": "/path1/on/local/machine",
# "/path2/on/remote/machine": "/path2/on/local/machine",
# "~/path1/on/remote/machine": "/path1/on/local/machine",
# "~/path2/on/remote/machine": "/path2/on/local/machine",
}
# Note that the container images in this example have a non-root user.
# To avoid permissions issues, we recommend mounting into a subdirectory of home (~).
# Files or directories to copy from the head node to the worker nodes. The format is a
# list of paths. The same path on the head node will be copied to the worker node.
@@ -286,9 +286,11 @@ worker_nodes:
# Files or directories to copy to the head and worker nodes. The format is a
# dictionary from REMOTE_PATH: LOCAL_PATH, e.g.
file_mounts: {
# "/path1/on/remote/machine": "/path1/on/local/machine",
# "/path2/on/remote/machine": "/path2/on/local/machine",
# "~/path1/on/remote/machine": "/path1/on/local/machine",
# "~/path2/on/remote/machine": "/path2/on/local/machine",
}
# Note that the container images in this example have a non-root user.
# To avoid permissions issues, we recommend mounting into a subdirectory of home (~).
# List of commands that will be run before `setup_commands`. If docker is
# enabled, these commands will run outside the container and before docker
+4 -2
View File
@@ -142,7 +142,8 @@ class WorkerCrashedError(RayError):
"""Indicates that the worker died unexpectedly while executing a task."""
def __str__(self):
return "The worker died unexpectedly while executing this task."
return ("The worker died unexpectedly while executing this task. "
"Check python-core-worker-*.log files for more information.")
class RayActorError(RayError):
@@ -153,7 +154,8 @@ class RayActorError(RayError):
"""
def __str__(self):
return "The actor died unexpectedly before finishing this task."
return ("The actor died unexpectedly before finishing this task. "
"Check python-core-worker-*.log files for more information.")
class RaySystemError(RayError):
-2
View File
@@ -1,6 +1,4 @@
from .dynamic_resources import set_resource
from .object_spilling import force_spill_objects
__all__ = [
"set_resource",
"force_spill_objects",
]
+82 -99
View File
@@ -1,117 +1,100 @@
from ray.experimental.client.api import ClientAPI
from ray.experimental.client.api import APIImpl
from typing import Optional, List, Tuple
from contextlib import contextmanager
from typing import List, Tuple
import logging
logger = logging.getLogger(__name__)
# About these global variables: Ray 1.0 uses exported module functions to
# provide its API, and we need to match that. However, we want different
# behaviors depending on where, exactly, in the client stack this is running.
#
# The reason for these differences depends on what's being pickled and passed
# to functions, or functions inside functions. So there are three cases to care
# about
#
# (Python Client)-->(Python ClientServer)-->(Internal Raylet Process)
#
# * _client_api should be set if we're inside the client
# * _server_api should be set if we're inside the clientserver
# * Both will be set if we're running both (as in a test)
# * Neither should be set if we're inside the raylet (but we still need to shim
# from the client API surface to the Ray API)
#
# The job of RayAPIStub (below) delegates to the appropriate one of these
# depending on what's set or not. Then, all users importing the ray object
# from this package get the stub which routes them to the appropriate APIImpl.
_client_api: Optional[APIImpl] = None
_server_api: Optional[APIImpl] = None
# The reason for _is_server is a hack around the above comment while running
# tests. If we have both a client and a server trying to control these static
# variables then we need a way to decide which to use. In this case, both
# _client_api and _server_api are set.
# This boolean flips between the two
_is_server: bool = False
@contextmanager
def stash_api_for_tests(in_test: bool):
global _is_server
is_server = _is_server
if in_test:
_is_server = True
yield _server_api
if in_test:
_is_server = is_server
def _set_client_api(val: Optional[APIImpl]):
global _client_api
global _is_server
if _client_api is not None:
raise Exception("Trying to set more than one client API")
_client_api = val
_is_server = False
def _set_server_api(val: Optional[APIImpl]):
global _server_api
global _is_server
if _server_api is not None:
raise Exception("Trying to set more than one server API")
_server_api = val
_is_server = True
def reset_api():
global _client_api
global _server_api
global _is_server
_client_api = None
_server_api = None
_is_server = False
def _get_client_api() -> APIImpl:
global _client_api
global _server_api
global _is_server
api = None
if _is_server:
api = _server_api
else:
api = _client_api
if api is None:
# We're inside a raylet worker
from ray.experimental.client.server.core_ray_api import CoreRayAPI
return CoreRayAPI()
return api
class RayAPIStub:
"""This class stands in as the replacement API for the `import ray` module.
Much like the ray module, this mostly delegates the work to the
_client_worker. As parts of the ray API are covered, they are piped through
here or on the client worker API.
"""
def __init__(self):
from ray.experimental.client.api import ClientAPI
self.api = ClientAPI()
self.client_worker = None
self._server = None
self._connected_with_init = False
self._inside_client_test = False
def connect(self,
conn_str: str,
secure: bool = False,
metadata: List[Tuple[str, str]] = None,
stub=None):
metadata: List[Tuple[str, str]] = None) -> None:
"""Connect the Ray Client to a server.
Args:
conn_str: Connection string, in the form "[host]:port"
secure: Whether to use a TLS secured gRPC channel
metadata: gRPC metadata to send on connect
"""
# Delay imports until connect to avoid circular imports.
from ray.experimental.client.worker import Worker
_client_worker = Worker(
conn_str, secure=secure, metadata=metadata, stub=stub)
_set_client_api(ClientAPI(_client_worker))
import ray._private.client_mode_hook
if self.client_worker is not None:
if self._connected_with_init:
return
raise Exception(
"ray.connect() called, but ray client is already connected")
if not self._inside_client_test:
# If we're calling a client connect specifically and we're not
# currently in client mode, ensure we are.
ray._private.client_mode_hook._explicitly_enable_client_mode()
self.client_worker = Worker(conn_str, secure=secure, metadata=metadata)
self.api.worker = self.client_worker
def disconnect(self):
global _client_api
if _client_api is not None:
_client_api.close()
_client_api = None
"""Disconnect the Ray Client.
"""
if self.client_worker is not None:
self.client_worker.close()
self.client_worker = None
# remote can be called outside of a connection, which is why it
# exists on the same API layer as connect() itself.
def remote(self, *args, **kwargs):
"""remote is the hook stub passed on to replace `ray.remote`.
This sets up remote functions or actors, as the decorator,
but does not execute them.
Args:
args: opaque arguments
kwargs: opaque keyword arguments
"""
return self.api.remote(*args, **kwargs)
def __getattr__(self, key: str):
global _get_client_api
api = _get_client_api()
return getattr(api, key)
if not self.is_connected():
raise Exception("Ray Client is not connected. "
"Please connect by calling `ray.connect`.")
return getattr(self.api, key)
def is_connected(self) -> bool:
    """Return True while a client worker connection is held.

    `self.api` is assigned unconditionally in `__init__`, so testing it
    would always report "connected" and the guard in `__getattr__` could
    never raise. `self.client_worker` is what `connect()` sets and
    `disconnect()` resets to None, so it tracks the connection state.
    """
    return self.client_worker is not None
def init(self, *args, **kwargs):
if self._server is not None:
raise Exception("Trying to start two instances of ray via client")
import ray.experimental.client.server.server as ray_client_server
self._server, address_info = ray_client_server.init_and_serve(
"localhost:50051", *args, **kwargs)
self.connect("localhost:50051")
self._connected_with_init = True
return address_info
def shutdown(self, _exiting_interpreter=False):
self.disconnect()
import ray.experimental.client.server.server as ray_client_server
if self._server is None:
return
ray_client_server.shutdown_with_server(self._server,
_exiting_interpreter)
self._server = None
ray = RayAPIStub()
+82 -96
View File
@@ -1,74 +1,51 @@
# This file defines an interface and client-side API stub
# for referring either to the core Ray API or the same interface
# from the Ray client.
#
# In tandem with __init__.py, we want to expose an API that's
# close to `python/ray/__init__.py` but with more than one implementation.
# The stubs in __init__ should call into a well-defined interface.
# Only the core Ray API implementation should actually `import ray`
# (and thus import all the raylet worker C bindings and such).
# But to make sure that we're matching these calls, we define this API.
from abc import ABC
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Union, Optional
import ray.core.generated.ray_client_pb2 as ray_client_pb2
"""This file defines the interface between the ray client worker
and the overall ray module API.
"""
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientStub
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientObjectRef
from ray._raylet import ObjectRef
# Use the imports for type checking. This is a python 3.6 limitation.
# See https://www.python.org/dev/peps/pep-0563/
PutType = Union[ClientObjectRef, ObjectRef]
class APIImpl(ABC):
"""
APIImpl is the interface to implement for whichever version of the core
Ray API that needs abstracting when run in client mode.
class ClientAPI:
"""The Client-side methods corresponding to the ray API. Delegates
to the Client Worker that contains the connection to the ClientServer.
"""
@abstractmethod
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
"""
get is the hook stub passed on to replace `ray.get`
def __init__(self, worker=None):
self.worker = worker
def get(self, vals, *, timeout=None):
"""get is the hook stub passed on to replace `ray.get`
Args:
vals: [Client]ObjectRef or list of these refs to retrieve.
timeout: Optional timeout in milliseconds
"""
pass
return self.worker.get(vals, timeout=timeout)
@abstractmethod
def put(self, vals: Any, *args,
**kwargs) -> Union["ClientObjectRef", "ObjectRef"]:
"""
put is the hook stub passed on to replace `ray.put`
def put(self, *args, **kwargs):
"""put is the hook stub passed on to replace `ray.put`
Args:
vals: The value or list of values to `put`.
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
return self.worker.put(*args, **kwargs)
@abstractmethod
def wait(self, *args, **kwargs):
"""
wait is the hook stub passed on to replace `ray.wait`
"""wait is the hook stub passed on to replace `ray.wait`
Args:
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
return self.worker.wait(*args, **kwargs)
@abstractmethod
def remote(self, *args, **kwargs):
"""
remote is the hook stub passed on to replace `ray.remote`.
"""remote is the hook stub passed on to replace `ray.remote`.
This sets up remote functions or actors, as the decorator,
but does not execute them.
@@ -77,12 +54,24 @@ class APIImpl(ABC):
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
# Delayed import to avoid a cyclic import
from ray.experimental.client.common import remote_decorator
if len(args) == 1 and len(kwargs) == 0 and callable(args[0]):
# This is the case where the decorator is just @ray.remote.
return remote_decorator(options=None)(args[0])
error_string = ("The @ray.remote decorator must be applied either "
"with no arguments and no parentheses, for example "
"'@ray.remote', or it must be applied using some of "
"the arguments 'num_returns', 'num_cpus', 'num_gpus', "
"'memory', 'object_store_memory', 'resources', "
"'max_calls', or 'max_restarts', like "
"'@ray.remote(num_returns=2, "
"resources={\"CustomResource\": 1})'.")
assert len(args) == 0 and len(kwargs) > 0, error_string
return remote_decorator(options=kwargs)
@abstractmethod
def call_remote(self, instance: "ClientStub", *args, **kwargs):
"""
call_remote is called by stub objects to execute them remotely.
"""call_remote is called by stub objects to execute them remotely.
This is used by stub objects in situations where they're called
with .remote, eg, `f.remote()` or `actor_cls.remote()`.
@@ -95,31 +84,57 @@ class APIImpl(ABC):
args: opaque arguments
kwargs: opaque keyword arguments
"""
pass
return self.worker.call_remote(instance, *args, **kwargs)
@abstractmethod
def close(self) -> None:
def call_release(self, id: bytes) -> None:
"""Attempts to release an object reference.
When client references are destructed, they release their reference,
which can opportunistically send a notification through the datachannel
to release the reference being held for that object on the server.
Args:
id: The id of the reference to release on the server side.
"""
close cleans up an API connection by closing any channels or
return self.worker.call_release(id)
def call_retain(self, id: bytes) -> None:
"""Attempts to retain a client object reference.
Increments the reference count on the client side, to prevent
the client worker from attempting to release the server reference.
Args:
id: The id of the reference to retain on the client side.
"""
return self.worker.call_retain(id)
def close(self) -> None:
"""close cleans up an API connection by closing any channels or
shutting down any servers gracefully.
"""
pass
return self.worker.close()
@abstractmethod
def kill(self, actor, *, no_restart=True):
def get_actor(self, name: str) -> "ClientActorHandle":
"""Returns a handle to an actor by name.
Args:
name: The name passed to this actor by
Actor.options(name="name").remote()
"""
kill forcibly stops an actor running in the cluster
return self.worker.get_actor(name)
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
"""kill forcibly stops an actor running in the cluster
Args:
no_restart: Whether this actor should be restarted if it's a
restartable actor.
"""
pass
return self.worker.terminate_actor(actor, no_restart)
@abstractmethod
def cancel(self, obj, *, force=False, recursive=True):
"""
Cancels a task on the cluster.
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
"""Cancels a task on the cluster.
If the specified task is pending execution, it will not be executed. If
the task is currently executing, the behavior depends on the ``force``
@@ -136,46 +151,11 @@ class APIImpl(ABC):
recursive (boolean): Whether to try to cancel tasks submitted by
the task specified.
"""
pass
class ClientAPI(APIImpl):
"""
The Client-side methods corresponding to the ray API. Delegates
to the Client Worker that contains the connection to the ClientServer.
"""
def __init__(self, worker):
self.worker = worker
def get(self, vals, *, timeout=None):
return self.worker.get(vals, timeout=timeout)
def put(self, *args, **kwargs):
return self.worker.put(*args, **kwargs)
def wait(self, *args, **kwargs):
return self.worker.wait(*args, **kwargs)
def remote(self, *args, **kwargs):
return self.worker.remote(*args, **kwargs)
def call_remote(self, instance: "ClientStub", *args, **kwargs):
return self.worker.call_remote(instance, *args, **kwargs)
def close(self) -> None:
return self.worker.close()
def kill(self, actor: "ClientActorHandle", *, no_restart=True):
return self.worker.terminate_actor(actor, no_restart)
def cancel(self, obj: "ClientObjectRef", *, force=False, recursive=True):
return self.worker.terminate_task(obj, force, recursive)
# Various metadata methods for the client that are defined in the protocol.
def is_initialized(self) -> bool:
""" True if our client is connected, and if the server is initialized.
"""True if our client is connected, and if the server is initialized.
Returns:
A boolean determining if the client is connected and
server initialized.
@@ -188,6 +168,8 @@ class ClientAPI(APIImpl):
Returns:
Information about the Ray clients in the cluster.
"""
# This should be imported here, otherwise, it will error doc build.
import ray.core.generated.ray_client_pb2 as ray_client_pb2
return self.worker.get_cluster_info(
ray_client_pb2.ClusterInfoType.NODES)
@@ -201,6 +183,8 @@ class ClientAPI(APIImpl):
A dictionary mapping resource name to the total quantity of that
resource in the cluster.
"""
# This should be imported here, otherwise, it will error doc build.
import ray.core.generated.ray_client_pb2 as ray_client_pb2
return self.worker.get_cluster_info(
ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES)
@@ -216,6 +200,8 @@ class ClientAPI(APIImpl):
A dictionary mapping resource name to the total quantity of that
resource in the cluster.
"""
# This should be imported here, otherwise, it will error doc build.
import ray.core.generated.ray_client_pb2 as ray_client_pb2
return self.worker.get_cluster_info(
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES)
@@ -0,0 +1,182 @@
"""Implements the client side of the client/server pickling protocol.
All ray client client/server data transfer happens through this pickling
protocol. The model is as follows:
* All Client objects (eg ClientObjectRef) always live on the client and
are never represented in the server
* All Ray objects (eg, ray.ObjectRef) always live on the server and are
never returned to the client
* In order to translate between these two references, PickleStub tuples
are generated as persistent ids in the data blobs during the pickling
and unpickling of these objects.
The PickleStubs have just enough information to find or generate their
associated partner object on either side.
This also has the advantage of avoiding predefined pickle behavior for ray
objects, which may include ray internal reference counting.
ClientPickler dumps things from the client into the appropriate stubs
ServerUnpickler loads stubs from the server into their client counterparts.
"""
import cloudpickle
import io
import sys
from typing import NamedTuple
from typing import Any
from typing import Dict
from typing import Optional
from ray.experimental.client import RayAPIStub
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientActorRef
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.common import ClientRemoteMethod
from ray.experimental.client.common import OptionWrapper
from ray.experimental.client.common import SelfReferenceSentinel
import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray._private.client_mode_hook import disable_client_hook
if sys.version_info < (3, 8):
try:
import pickle5 as pickle # noqa: F401
except ImportError:
import pickle # noqa: F401
else:
import pickle # noqa: F401
# NOTE(barakmich): These PickleStubs are really close to
# the data for an execution, with no arguments. Combine the two?
PickleStub = NamedTuple("PickleStub",
[("type", str), ("client_id", str), ("ref_id", bytes),
("name", Optional[str]),
("baseline_options", Optional[Dict])])
class ClientPickler(cloudpickle.CloudPickler):
    """Pickler for data leaving the client.

    Client-side handles are not pickled by value; `persistent_id` replaces
    them with PickleStub tuples tagged with this pickler's client id.
    """

    def __init__(self, client_id, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.client_id = client_id

    def _stub(self, kind, ref_id=b"", name=None, baseline_options=None):
        """Build a PickleStub of the given type for this client."""
        return PickleStub(
            type=kind,
            client_id=self.client_id,
            ref_id=ref_id,
            name=name,
            baseline_options=baseline_options,
        )

    def _stub_for_remote(self, obj, kind, self_reference_kind):
        """Build the stub for a remote function or actor class."""
        # TODO(barakmich): This is going to have trouble with mutually
        # recursive functions that haven't, as yet, been executed. It's
        # relatively doable (keep track of intermediate refs in progress
        # with ensure_ref and return appropriately) But punting for now.
        if obj._ref is None:
            obj._ensure_ref()
        if type(obj._ref) == SelfReferenceSentinel:
            # The object is being pickled while its own ref is still being
            # created; emit a marker stub instead of a real ref id.
            return self._stub(self_reference_kind)
        return self._stub(
            kind, ref_id=obj._ref.id, baseline_options=obj._options)

    def persistent_id(self, obj):
        """Return a PickleStub for client-side objects, else None."""
        if isinstance(obj, RayAPIStub):
            return self._stub("Ray")
        if isinstance(obj, ClientObjectRef):
            return self._stub("Object", ref_id=obj.id)
        if isinstance(obj, ClientActorHandle):
            return self._stub("Actor", ref_id=obj._actor_id)
        if isinstance(obj, ClientRemoteFunc):
            return self._stub_for_remote(obj, "RemoteFunc",
                                         "RemoteFuncSelfReference")
        if isinstance(obj, ClientActorClass):
            return self._stub_for_remote(obj, "RemoteActor",
                                         "RemoteActorSelfReference")
        if isinstance(obj, ClientRemoteMethod):
            return self._stub(
                "RemoteMethod",
                ref_id=obj.actor_handle.actor_ref.id,
                name=obj.method_name,
            )
        if isinstance(obj, OptionWrapper):
            raise NotImplementedError(
                "Sending a partial option is unimplemented")
        # Anything else is pickled normally.
        return None
class ServerUnpickler(pickle.Unpickler):
    """Unpickler for data arriving from the server.

    PickleStub persistent ids are turned back into their client-side
    counterparts.
    """

    def persistent_load(self, pid):
        """Map an "Object" or "Actor" PickleStub to a client reference.

        Raises:
            NotImplementedError: For any other stub type.
        """
        assert isinstance(pid, PickleStub)
        kind = pid.type
        if kind == "Object":
            return ClientObjectRef(id=pid.ref_id)
        if kind == "Actor":
            return ClientActorHandle(ClientActorRef(id=pid.ref_id))
        raise NotImplementedError("Being passed back an unknown stub")
def dumps_from_client(obj: Any, client_id: str, protocol=None) -> bytes:
    """Serialize `obj` with ClientPickler and return the resulting bytes.

    The pickling runs inside `disable_client_hook()`.
    """
    with disable_client_hook(), io.BytesIO() as buffer:
        ClientPickler(client_id, buffer, protocol=protocol).dump(obj)
        return buffer.getvalue()
def loads_from_server(data: bytes,
                      *,
                      fix_imports=True,
                      encoding="ASCII",
                      errors="strict") -> Any:
    """Deserialize bytes received from the server with ServerUnpickler.

    The keyword arguments are forwarded unchanged to the unpickler.

    Raises:
        TypeError: If `data` is a `str` rather than bytes.
    """
    if isinstance(data, str):
        raise TypeError("Can't load pickle from unicode string")
    unpickler = ServerUnpickler(
        io.BytesIO(data),
        fix_imports=fix_imports,
        encoding=encoding,
        errors=errors)
    return unpickler.load()
def convert_to_arg(val: Any, client_id: str) -> ray_client_pb2.Arg:
    """Wrap `val` in an Arg message.

    The message's `local` field is set to `Locality.INTERNED` and its `data`
    field to the bytes produced by `dumps_from_client`.
    """
    arg = ray_client_pb2.Arg()
    arg.local = ray_client_pb2.Arg.Locality.INTERNED
    arg.data = dumps_from_client(val, client_id)
    return arg
+151 -130
View File
@@ -1,16 +1,31 @@
import ray.core.generated.ray_client_pb2 as ray_client_pb2
from ray.experimental.client import ray
from typing import Any
from typing import Dict
from ray import cloudpickle
from ray.experimental.client.options import validate_options
import base64
import inspect
from ray.util.inspect import is_cython
import json
import threading
from typing import Any
from typing import List
from typing import Dict
from typing import Optional
from typing import Union
class ClientBaseRef:
def __init__(self, id, handle=None):
self.id = id
self.handle = handle
def __init__(self, id: bytes):
self.id = None
if not isinstance(id, bytes):
raise TypeError("ClientRefs must be created with bytes IDs")
self.id: bytes = id
ray.call_retain(id)
def binary(self):
return self.id
def __eq__(self, other):
return self.id == other.id
def __repr__(self):
return "%s(%s)" % (
@@ -18,20 +33,16 @@ class ClientBaseRef:
self.id.hex(),
)
def __eq__(self, other):
return self.id == other.id
def __hash__(self):
return hash(self.id)
def binary(self):
return self.id
@classmethod
def from_remote_ref(cls, ref: ray_client_pb2.RemoteRef):
return cls(id=ref.id, handle=ref.handle)
def __del__(self):
if ray.is_connected() and self.id is not None:
ray.call_release(self.id)
class ClientObjectRef(ClientBaseRef):
def _unpack_ref(self):
return cloudpickle.loads(self.handle)
pass
class ClientActorRef(ClientBaseRef):
@@ -43,8 +54,7 @@ class ClientStub:
class ClientRemoteFunc(ClientStub):
"""
A stub created on the Ray Client to represent a remote
"""A stub created on the Ray Client to represent a remote
function that can be exectued on the cluster.
This class is allowed to be passed around between remote functions.
@@ -53,55 +63,57 @@ class ClientRemoteFunc(ClientStub):
_func: The actual function to execute remotely
_name: The original name of the function
_ref: The ClientObjectRef of the pickled code of the function, _func
_raylet_remote: The Raylet-side ray.remote_function.RemoteFunction
for this object
"""
def __init__(self, f):
def __init__(self, f, options=None):
self._lock = threading.Lock()
self._func = f
self._name = f.__name__
self.id = None
# self._ref can be lazily instantiated. Rather than eagerly creating
# function data objects in the server we can put them just before we
# execute the function, especially in cases where many @ray.remote
# functions exist in a library and only a handful are ever executed by
# a user of the library.
#
# TODO(barakmich): This ref might actually be better as a serialized
# ObjectRef. This requires being able to serialize the ref without
# pinning it (as the lifetime of the ref is tied with the server, not
# the client)
self._ref = None
self._raylet_remote = None
self._options = validate_options(options)
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote function cannot be called directly. "
"Use {self._name}.remote method instead")
def remote(self, *args, **kwargs):
return ray.call_remote(self, *args, **kwargs)
return return_refs(ray.call_remote(self, *args, **kwargs))
def _get_ray_remote_impl(self):
if self._raylet_remote is None:
self._raylet_remote = ray.remote(self._func)
return self._raylet_remote
def options(self, **kwargs):
return OptionWrapper(self, kwargs)
def _remote(self, args=[], kwargs={}, **option_args):
return self.options(**option_args).remote(*args, **kwargs)
def __repr__(self):
return "ClientRemoteFunc(%s, %s)" % (self._name, self._ref)
def _ensure_ref(self):
with self._lock:
if self._ref is None:
# While calling ray.put() on our function, if
# our function is recursive, it will attempt to
# encode the ClientRemoteFunc -- itself -- and
# infinitely recurse on _ensure_ref.
#
# So we set the state of the reference to be an
# in-progress self reference value, which
# the encoding can detect and handle correctly.
self._ref = SelfReferenceSentinel()
self._ref = ray.put(self._func)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
if self._ref is None:
self._ref = ray.put(self._func)
self._ensure_ref()
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.FUNCTION
task.name = self._name
task.payload_id = self._ref.handle
task.payload_id = self._ref.id
set_task_options(task, self._options, "baseline_options")
return task
class ClientActorClass(ClientStub):
""" A stub created on the Ray Client to represent an actor class.
"""A stub created on the Ray Client to represent an actor class.
It is wrapped by ray.remote and can be executed on the cluster.
@@ -109,39 +121,40 @@ class ClientActorClass(ClientStub):
actor_cls: The actual class to execute remotely
_name: The original name of the class
_ref: The ClientObjectRef of the pickled `actor_cls`
_raylet_remote: The Raylet-side ray.ActorClass for this object
"""
def __init__(self, actor_cls):
def __init__(self, actor_cls, options=None):
self.actor_cls = actor_cls
self._name = actor_cls.__name__
self._ref = None
self._raylet_remote = None
self._options = validate_options(options)
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote actor cannot be instantiated directly. "
"Use {self._name}.remote() instead")
def __getstate__(self) -> Dict:
state = {
"actor_cls": self.actor_cls,
"_name": self._name,
"_ref": self._ref,
}
return state
def _ensure_ref(self):
if self._ref is None:
# As before, set the state of the reference to be an
# in-progress self reference value, which
# the encoding can detect and handle correctly.
self._ref = SelfReferenceSentinel()
self._ref = ray.put(self.actor_cls)
def __setstate__(self, state: Dict) -> None:
self.actor_cls = state["actor_cls"]
self._name = state["_name"]
self._ref = state["_ref"]
def remote(self, *args, **kwargs):
def remote(self, *args, **kwargs) -> "ClientActorHandle":
# Actually instantiate the actor
ref = ray.call_remote(self, *args, **kwargs)
return ClientActorHandle(ClientActorRef(ref.id, ref.handle), self)
ref_ids = ray.call_remote(self, *args, **kwargs)
assert len(ref_ids) == 1
return ClientActorHandle(ClientActorRef(ref_ids[0]))
def options(self, **kwargs):
return ActorOptionWrapper(self, kwargs)
def _remote(self, args=[], kwargs={}, **option_args):
return self.options(**option_args).remote(*args, **kwargs)
def __repr__(self):
return "ClientRemoteActor(%s, %s)" % (self._name, self._ref)
return "ClientActorClass(%s, %s)" % (self._name, self._ref)
def __getattr__(self, key):
if key not in self.__dict__:
@@ -149,12 +162,12 @@ class ClientActorClass(ClientStub):
raise NotImplementedError("static methods")
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
if self._ref is None:
self._ref = ray.put(self.actor_cls)
self._ensure_ref()
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.ACTOR
task.name = self._name
task.payload_id = self._ref.handle
task.payload_id = self._ref.id
set_task_options(task, self._options, "baseline_options")
return task
@@ -174,29 +187,12 @@ class ClientActorHandle(ClientStub):
ray.actor.ActorHandle contained in the actor_id ref.
"""
def __init__(self, actor_ref: ClientActorRef,
actor_class: ClientActorClass):
def __init__(self, actor_ref: ClientActorRef):
self.actor_ref = actor_ref
self.actor_class = actor_class
self._real_actor_handle = None
def _get_ray_remote_impl(self):
if self._real_actor_handle is None:
self._real_actor_handle = cloudpickle.loads(self.actor_ref.handle)
return self._real_actor_handle
def __getstate__(self) -> Dict:
state = {
"actor_ref": self.actor_ref,
"actor_class": self.actor_class,
"_real_actor_handle": self._real_actor_handle,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_ref = state["actor_ref"]
self.actor_class = state["actor_class"]
self._real_actor_handle = state["_real_actor_handle"]
def __del__(self) -> None:
if ray.is_connected():
ray.call_release(self.actor_ref.id)
@property
def _actor_id(self):
@@ -226,65 +222,90 @@ class ClientRemoteMethod(ClientStub):
def __call__(self, *args, **kwargs):
raise TypeError(f"Remote method cannot be called directly. "
"Use {self._name}.remote() instead")
def _get_ray_remote_impl(self):
return getattr(self.actor_handle._get_ray_remote_impl(),
self.method_name)
def __getstate__(self) -> Dict:
state = {
"actor_handle": self.actor_handle,
"method_name": self.method_name,
}
return state
def __setstate__(self, state: Dict) -> None:
self.actor_handle = state["actor_handle"]
self.method_name = state["method_name"]
f"Use {self._name}.remote() instead")
def remote(self, *args, **kwargs):
return ray.call_remote(self, *args, **kwargs)
return return_refs(ray.call_remote(self, *args, **kwargs))
def __repr__(self):
name = "%s.%s" % (self.actor_handle.actor_class._name,
self.method_name)
return "ClientRemoteMethod(%s, %s)" % (name,
self.actor_handle.actor_id)
return "ClientRemoteMethod(%s, %s)" % (self.method_name,
self.actor_handle)
def options(self, **kwargs):
return OptionWrapper(self, kwargs)
def _remote(self, args=[], kwargs={}, **option_args):
return self.options(**option_args).remote(*args, **kwargs)
def _prepare_client_task(self) -> ray_client_pb2.ClientTask:
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.METHOD
task.name = self.method_name
task.payload_id = self.actor_handle.actor_ref.handle
task.payload_id = self.actor_handle.actor_ref.id
return task
def convert_from_arg(pb) -> Any:
if pb.local == ray_client_pb2.Arg.Locality.REFERENCE:
return ClientObjectRef(pb.reference_id)
elif pb.local == ray_client_pb2.Arg.Locality.INTERNED:
return cloudpickle.loads(pb.data)
class OptionWrapper:
def __init__(self, stub: ClientStub, options: Optional[Dict[str, Any]]):
self.remote_stub = stub
self.options = validate_options(options)
raise Exception("convert_from_arg: Uncovered locality enum")
def remote(self, *args, **kwargs):
return return_refs(ray.call_remote(self, *args, **kwargs))
def __getattr__(self, key):
return getattr(self.remote_stub, key)
def _prepare_client_task(self):
task = self.remote_stub._prepare_client_task()
set_task_options(task, self.options)
return task
def convert_to_arg(val):
out = ray_client_pb2.Arg()
if isinstance(val, ClientObjectRef):
out.local = ray_client_pb2.Arg.Locality.REFERENCE
out.reference_id = val.id
else:
out.local = ray_client_pb2.Arg.Locality.INTERNED
out.data = cloudpickle.dumps(val)
return out
class ActorOptionWrapper(OptionWrapper):
def remote(self, *args, **kwargs):
ref_ids = ray.call_remote(self, *args, **kwargs)
assert len(ref_ids) == 1
return ClientActorHandle(ClientActorRef(ref_ids[0]))
def encode_exception(exception) -> str:
data = cloudpickle.dumps(exception)
return base64.standard_b64encode(data).decode()
def set_task_options(task: ray_client_pb2.ClientTask,
                     options: Optional[Dict[str, Any]],
                     field: str = "options") -> None:
    """Store ``options`` on ``task`` as JSON, or clear the field.

    Args:
        task: The task message to modify in place.
        options: Option mapping to serialize; ``None`` clears ``field``.
        field: Name of the attribute on ``task`` that carries the
            ``json_options`` string.
    """
    if options is not None:
        getattr(task, field).json_options = json.dumps(options)
        return
    task.ClearField(field)
def decode_exception(data) -> Exception:
data = base64.standard_b64decode(data)
return cloudpickle.loads(data)
def return_refs(ids: List[bytes]
                ) -> Union[None, ClientObjectRef, List[ClientObjectRef]]:
    """Wrap returned ids in ``ClientObjectRef`` objects.

    Args:
        ids: The ids returned for a remote call.

    Returns:
        ``None`` for no ids, a single ``ClientObjectRef`` for exactly one
        id, otherwise a list with one ``ClientObjectRef`` per id.
    """
    count = len(ids)
    if count == 0:
        return None
    if count == 1:
        return ClientObjectRef(ids[0])
    return [ClientObjectRef(ref_id) for ref_id in ids]
class DataEncodingSentinel:
def __repr__(self) -> str:
return self.__class__.__name__
class SelfReferenceSentinel(DataEncodingSentinel):
pass
def remote_decorator(options: Optional[Dict[str, Any]]):
    """Build a decorator that wraps a function or class in a client stub.

    Args:
        options: Options forwarded to the created stub.

    Returns:
        A decorator producing a ``ClientRemoteFunc`` for functions (or
        Cython callables) and a ``ClientActorClass`` for classes.
    """

    def decorator(function_or_class) -> ClientStub:
        wraps_function = (inspect.isfunction(function_or_class)
                          or is_cython(function_or_class))
        if wraps_function:
            return ClientRemoteFunc(function_or_class, options=options)
        if inspect.isclass(function_or_class):
            return ClientActorClass(function_or_class, options=options)
        raise TypeError("The @ray.remote decorator must be applied to "
                        "either a function or to a class.")

    return decorator
@@ -0,0 +1,108 @@
"""This file implements a threaded stream controller to abstract a data stream
back to the ray clientserver.
"""
import logging
import queue
import threading
import grpc
from typing import Any
from typing import Dict
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
logger = logging.getLogger(__name__)
# The maximum field value for request_id -- which is also the maximum
# number of simultaneous in-flight requests.
INT32_MAX = (2**31) - 1
class DataClient:
    """Multiplexes blocking get/put calls over one gRPC data stream.

    Requests are queued for a background thread, which feeds them to the
    ``Datapath`` RPC. Responses are matched back to the waiting caller by
    ``req_id``.
    """

    def __init__(self, channel: "grpc._channel.Channel", client_id: str):
        """Initializes a thread-safe datapath over a Ray Client gRPC channel.

        Args:
            channel: connected gRPC channel
            client_id: the generated ID representing this client
        """
        self.channel = channel
        self.request_queue = queue.Queue()
        self.ready_data: Dict[int, Any] = {}
        self.cv = threading.Condition()
        # Serializes request-id allocation so concurrent callers never
        # receive the same id.
        self._id_lock = threading.Lock()
        # Set (under self.cv) once the response stream has ended, so blocked
        # callers are woken instead of waiting forever.
        self._stream_closed = False
        self._req_id = 0
        self._client_id = client_id
        # All state above must exist before the thread starts reading it.
        self.data_thread = self._start_datathread()
        self.data_thread.start()

    def _next_id(self) -> int:
        """Return the next request id, wrapping from INT32_MAX back to 1."""
        with self._id_lock:
            self._req_id += 1
            if self._req_id > INT32_MAX:
                self._req_id = 1
            # Responses that aren't tracked (like opportunistic releases)
            # have req_id=0, so make sure we never mint such an id.
            assert self._req_id != 0
            return self._req_id

    def _start_datathread(self) -> threading.Thread:
        """Create (but do not start) the daemon thread running the stream."""
        return threading.Thread(target=self._data_main, args=(), daemon=True)

    def _data_main(self) -> None:
        """Thread body: send queued requests and publish their responses."""
        stub = ray_client_pb2_grpc.RayletDataStreamerStub(self.channel)
        # A None placed on the queue ends the request iterator.
        resp_stream = stub.Datapath(
            iter(self.request_queue.get, None),
            metadata=(("client_id", self._client_id), ))
        try:
            for response in resp_stream:
                if response.req_id == 0:
                    # This is not being waited for.
                    logger.debug(f"Got unawaited response {response}")
                    continue
                with self.cv:
                    self.ready_data[response.req_id] = response
                    self.cv.notify_all()
        except grpc.RpcError as e:
            if grpc.StatusCode.CANCELLED == e.code():
                # Gracefully shutting down
                logger.info("Cancelling data channel")
            else:
                logger.error(
                    f"Got Error from data channel -- shutting down: {e}")
                raise e
        finally:
            # No further responses can arrive; release any blocked caller.
            with self.cv:
                self._stream_closed = True
                self.cv.notify_all()

    def close(self) -> None:
        """Signal the stream to end and wait for the data thread to exit."""
        if self.request_queue is not None:
            self.request_queue.put(None)
        if self.data_thread is not None:
            self.data_thread.join()

    def _blocking_send(self, req: ray_client_pb2.DataRequest
                       ) -> ray_client_pb2.DataResponse:
        """Send ``req`` and block until its response arrives.

        Args:
            req: The request; its ``req_id`` is overwritten with a fresh id.

        Returns:
            The response carrying the same ``req_id``.

        Raises:
            ConnectionError: If the data stream ended before a response for
                this request was received.
        """
        req_id = self._next_id()
        req.req_id = req_id
        self.request_queue.put(req)
        with self.cv:
            self.cv.wait_for(
                lambda: req_id in self.ready_data or self._stream_closed)
            if req_id not in self.ready_data:
                raise ConnectionError(
                    "Data channel closed before a response was received "
                    f"for request {req_id}")
            return self.ready_data.pop(req_id)

    def GetObject(self, request: ray_client_pb2.GetRequest,
                  context=None) -> ray_client_pb2.GetResponse:
        """Send a get request and return the ``get`` part of the response."""
        datareq = ray_client_pb2.DataRequest(get=request, )
        resp = self._blocking_send(datareq)
        return resp.get

    def PutObject(self, request: ray_client_pb2.PutRequest,
                  context=None) -> ray_client_pb2.PutResponse:
        """Send a put request and return the ``put`` part of the response."""
        datareq = ray_client_pb2.DataRequest(put=request, )
        resp = self._blocking_send(datareq)
        return resp.put

    def ReleaseObject(self,
                      request: ray_client_pb2.ReleaseRequest,
                      context=None) -> None:
        """Queue a release request without waiting for its response."""
        # No req_id is assigned here, so the reply is treated as unawaited.
        datareq = ray_client_pb2.DataRequest(release=request, )
        self.request_queue.put(datareq)
@@ -0,0 +1,7 @@
# Example script: run a Tune job through the experimental Ray Client.
from ray.experimental.client import ray
from ray.tune import tune
# Connect the client to the server address localhost:50051.
ray.connect("localhost:50051")
# Run the "PG" trainable with the CartPole-v0 environment.
tune.run("PG", config={"env": "CartPole-v0"})
@@ -0,0 +1,86 @@
"""This file implements a threaded stream controller to return logs back from
the ray clientserver.
"""
import sys
import logging
import queue
import threading
import grpc
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
logger = logging.getLogger(__name__)
# TODO(barakmich): Running a logger in a logger causes loopback.
# The client logger needs its own root -- possibly this one.
# For the moment, let's just not propagate beyond this point.
logger.propagate = False
class LogstreamClient:
    """Receives log records from the Ray Client server on a daemon thread."""

    def __init__(self, channel: "grpc._channel.Channel"):
        """Initializes a thread-safe log stream over a Ray Client gRPC channel.

        Args:
            channel: connected gRPC channel
        """
        self.channel = channel
        self.request_queue = queue.Queue()
        self.log_thread = self._start_logthread()
        self.log_thread.start()

    def _start_logthread(self) -> threading.Thread:
        """Create (but do not start) the daemon thread reading the stream."""
        worker = threading.Thread(
            target=self._log_main, args=(), daemon=True)
        return worker

    def _log_main(self) -> None:
        """Thread body: forward every received record to the handlers."""
        stub = ray_client_pb2_grpc.RayletLogStreamerStub(self.channel)
        # A None placed on the queue ends the request iterator.
        outgoing = iter(self.request_queue.get, None)
        log_stream = stub.Logstream(outgoing)
        try:
            for record in log_stream:
                level, msg = record.level, record.msg
                # Negative levels mark stdout/stderr entries.
                if level < 0:
                    self.stdstream(level=level, msg=msg)
                self.log(level=level, msg=msg)
        except grpc.RpcError as e:
            if e.code() == grpc.StatusCode.CANCELLED:
                # Normal shutdown.
                return
            logger.error(
                f"Got Error from logger channel -- shutting down: {e}")
            raise e

    def log(self, level: int, msg: str):
        """Log the message from the log stream.

        By default, calls logger.log but this can be overridden.

        Args:
            level: The loglevel of the received log message
            msg: The content of the message
        """
        logger.log(level=level, msg=msg)

    def stdstream(self, level: int, msg: str):
        """Log the stdout/stderr entry from the log stream.

        By default, calls print but this can be overridden.

        Args:
            level: The loglevel of the received log message
            msg: The content of the message
        """
        # Level -2 goes to stderr; anything else goes to stdout.
        if level == -2:
            target = sys.stderr
        else:
            target = sys.stdout
        print(msg, file=target)

    def set_logstream_level(self, level: int):
        """Set the local logger level and request logs at that level."""
        logger.setLevel(level)
        req = ray_client_pb2.LogSettingsRequest()
        req.enabled = True
        req.loglevel = level
        self.request_queue.put(req)

    def close(self) -> None:
        """End the request stream and wait for the log thread to exit."""
        self.request_queue.put(None)
        thread = self.log_thread
        if thread is not None:
            thread.join()

    def disable_logs(self) -> None:
        """Ask the server to stop sending logs."""
        req = ray_client_pb2.LogSettingsRequest()
        req.enabled = False
        self.request_queue.put(req)
+54
View File
@@ -0,0 +1,54 @@
from typing import Any
from typing import Dict
from typing import Optional
# Known option names. An empty tuple means "accepted without validation";
# otherwise the tuple is (required type, predicate, error message).
options = {
    "num_returns": (int, lambda x: x >= 0,
                    "The keyword 'num_returns' only accepts 0 "
                    "or a positive integer"),
    "num_cpus": (),
    "num_gpus": (),
    "resources": (),
    "accelerator_type": (),
    "max_calls": (int, lambda x: x >= 0,
                  "The keyword 'max_calls' only accepts 0 "
                  "or a positive integer"),
    "max_restarts": (int, lambda x: x >= -1,
                     "The keyword 'max_restarts' only accepts -1, 0 "
                     "or a positive integer"),
    "max_task_retries": (int, lambda x: x >= -1,
                         "The keyword 'max_task_retries' only accepts -1, 0 "
                         "or a positive integer"),
    "max_retries": (int, lambda x: x >= -1,
                    "The keyword 'max_retries' only accepts 0, -1 "
                    "or a positive integer"),
    "max_concurrency": (),
    "name": (),
    "lifetime": (),
    "memory": (),
    "object_store_memory": (),
    "placement_group": (),
    "placement_group_bundle_index": (),
    "placement_group_capture_child_tasks": (),
    "override_environment_variables": (),
}


def validate_options(
        kwargs_dict: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Check remote() options against the ``options`` table.

    Args:
        kwargs_dict: Option names mapped to values, or ``None``.

    Returns:
        A new dict with the validated options, or ``None`` when the input
        is ``None`` or empty.

    Raises:
        TypeError: If an option name is not in the table.
        ValueError: If a value has the wrong type or fails its predicate.
    """
    if kwargs_dict is None or len(kwargs_dict) == 0:
        return None
    validated = {}
    for name, value in kwargs_dict.items():
        if name not in options:
            raise TypeError(f"Invalid option passed to remote(): {name}")
        spec = options[name]
        if spec:
            expected_type, predicate, message = spec
            # The predicate only runs once the type check has passed.
            if not (isinstance(value, expected_type) and predicate(value)):
                raise ValueError(message)
        validated[name] = value
    return validated
@@ -0,0 +1,17 @@
from contextlib import contextmanager
import ray.experimental.client.server.server as ray_client_server
from ray.experimental.client import ray
@contextmanager
def ray_start_client_server():
    """Start a client server on localhost:50051 and yield a connected client.

    Yields:
        The ``ray`` client object, connected to the started server.

    On exit the test flag is cleared, the client is disconnected and the
    server is stopped. Cleanup also runs when starting the server or
    connecting fails, so a failed setup leaves neither a running server nor
    a stale ``_inside_client_test`` flag behind.
    """
    ray._inside_client_test = True
    server = None
    connected = False
    try:
        server = ray_client_server.serve("localhost:50051")
        ray.connect("localhost:50051")
        connected = True
        yield ray
    finally:
        ray._inside_client_test = False
        try:
            # Only disconnect a connection that was actually established.
            if connected:
                ray.disconnect()
        finally:
            # Stop the server even if disconnecting raised.
            if server is not None:
                server.stop(0)
@@ -1,101 +0,0 @@
# Along with `api.py` this is the stub that interfaces with
# the real (C-binding, raylet) ray core.
#
# Ideally, the first import line is the only time we actually
# import ray in this library (excluding the main function for the server)
#
# While the stub is trivial, it allows us to check that the calls we're
# making into the core-ray module are contained and well-defined.
from typing import Any
from typing import Optional
from typing import Union
import ray
from ray.experimental.client.api import APIImpl
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientStub
class CoreRayAPI(APIImpl):
    """
    Implements the equivalent client-side Ray API by simply passing along to
    the Core Ray API. Primarily used inside of Ray Workers as a trampoline back
    to core ray when passed client stubs.
    """

    def get(self, vals, *, timeout: Optional[float] = None) -> Any:
        """Fetch one value or a list of values, unpacking client refs first.

        Args:
            vals: A ref, a list of refs, or anything ``ray.get`` accepts.
            timeout: Forwarded to ``ray.get``.

        Returns:
            Whatever ``ray.get`` returns for the (unpacked) input.
        """
        if isinstance(vals, list):
            # Check for emptiness before peeking at the first element; an
            # empty list is forwarded to ray.get unchanged instead of
            # raising IndexError here.
            if vals and isinstance(vals[0], ClientObjectRef):
                return ray.get(
                    [val._unpack_ref() for val in vals], timeout=timeout)
        elif isinstance(vals, ClientObjectRef):
            return ray.get(vals._unpack_ref(), timeout=timeout)
        return ray.get(vals, timeout=timeout)

    def put(self, vals: Any, *args,
            **kwargs) -> Union[ClientObjectRef, ray._raylet.ObjectRef]:
        """Forward to ``ray.put``."""
        return ray.put(vals, *args, **kwargs)

    def wait(self, *args, **kwargs):
        """Forward to ``ray.wait``."""
        return ray.wait(*args, **kwargs)

    def remote(self, *args, **kwargs):
        """Forward to ``ray.remote``."""
        return ray.remote(*args, **kwargs)

    def call_remote(self, instance: ClientStub, *args, **kwargs):
        """Invoke ``.remote()`` on the stub's underlying Ray object."""
        return instance._get_ray_remote_impl().remote(*args, **kwargs)

    def close(self) -> None:
        """No-op; there is nothing to close."""
        return None

    def kill(self, actor, *, no_restart=True):
        """Forward to ``ray.kill``."""
        return ray.kill(actor, no_restart=no_restart)

    def cancel(self, obj, *, force=False, recursive=True):
        """Forward to ``ray.cancel``."""
        return ray.cancel(obj, force=force, recursive=recursive)

    def is_initialized(self) -> bool:
        """Forward to ``ray.is_initialized``."""
        return ray.is_initialized()

    # Allow for generic fallback to ray.* in remote methods. This allows calls
    # like ray.nodes() to be run in remote functions even though the client
    # doesn't currently support them.
    def __getattr__(self, key: str):
        return getattr(ray, key)
class RayServerAPI(CoreRayAPI):
    """
    Ray Client server-side API shim. By default, simply calls the default Core
    Ray API calls, but also accepts scheduling calls from functions running
    inside of other remote functions that need to create more work.
    """

    def __init__(self, server_instance):
        self.server = server_instance

    def put(self, vals: Any, *args, **kwargs) -> ClientObjectRef:
        """Put one value, or each element of a list, through the server.

        Returns a single ref for a non-list input and a list of refs, in
        order, for a list input.
        """
        if isinstance(vals, list):
            return [self._put(item) for item in vals]
        return self._put(vals)

    def _put(self, val: Any):
        """Store ``val`` via the server and wrap the returned id."""
        stored = self.server._put_and_retain_obj(val)
        return ClientObjectRef(stored.id)

    def call_remote(self, instance: ClientStub, *args, **kwargs):
        """Schedule the stub's task on the server and wrap its return id."""
        ticket = self.server.Schedule(
            instance._prepare_client_task(), prepared_args=args)
        return ClientObjectRef(ticket.return_id)
@@ -0,0 +1,54 @@
import logging
import grpc
from typing import TYPE_CHECKING
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
if TYPE_CHECKING:
from ray.experimental.client.server.server import RayletServicer
logger = logging.getLogger(__name__)
class DataServicer(ray_client_pb2_grpc.RayletDataStreamerServicer):
    """Serves the streaming data path by delegating to the basic servicer."""

    def __init__(self, basic_service: "RayletServicer"):
        self.basic_service = basic_service

    def Datapath(self, request_iterator, context):
        """Answer each streamed get/put/release request in order.

        Args:
            request_iterator: Iterator of incoming data requests.
            context: gRPC context; its invocation metadata must carry a
                non-empty ``client_id``.

        Yields:
            One response per request, tagged with the request's ``req_id``.
        """
        metadata = dict(context.invocation_metadata())
        # A missing key is treated like an empty id, so the error below is
        # logged instead of a KeyError escaping.
        client_id = metadata.get("client_id", "")
        if client_id == "":
            logger.error("Client connecting with no client_id")
            return
        logger.info(f"New data connection from client {client_id}")
        try:
            for req in request_iterator:
                resp = None
                req_type = req.WhichOneof("type")
                if req_type == "get":
                    get_resp = self.basic_service._get_object(
                        req.get, client_id)
                    resp = ray_client_pb2.DataResponse(get=get_resp)
                elif req_type == "put":
                    put_resp = self.basic_service._put_object(
                        req.put, client_id)
                    resp = ray_client_pb2.DataResponse(put=put_resp)
                elif req_type == "release":
                    released = [
                        self.basic_service.release(client_id, rel_id)
                        for rel_id in req.release.ids
                    ]
                    resp = ray_client_pb2.DataResponse(
                        release=ray_client_pb2.ReleaseResponse(ok=released))
                else:
                    raise Exception(f"Unreachable code: Request type "
                                    f"{req_type} not handled in Datapath")
                # Echo the id so the client can match the response.
                resp.req_id = req.req_id
                yield resp
        except grpc.RpcError as e:
            logger.debug(f"Closing data channel: {e}")
        finally:
            # Drop everything this client held once its stream is gone.
            logger.info(f"Lost data connection from client {client_id}")
            self.basic_service.release_all(client_id)
@@ -0,0 +1,101 @@
"""This file responds to log stream requests and forwards logs
with its handler.
"""
import io
import threading
import queue
import logging
import grpc
import uuid
from ray.worker import print_worker_logs
from ray.ray_logging import global_worker_stdstream_dispatcher
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
logger = logging.getLogger(__name__)
class LogstreamHandler(logging.Handler):
    """Logging handler that puts each record on a queue as ``LogData``."""

    def __init__(self, queue, level):
        super().__init__()
        self.queue = queue
        self.level = level

    def emit(self, record: logging.LogRecord):
        """Convert ``record`` to a ``LogData`` message and enqueue it."""
        entry = ray_client_pb2.LogData()
        entry.msg = record.getMessage()
        entry.level = record.levelno
        entry.name = record.name
        self.queue.put(entry)
class StdStreamHandler:
    """Forwards worker stdout/stderr entries onto a queue as ``LogData``."""

    def __init__(self, queue):
        self.queue = queue
        # Random id under which this handler is (un)registered globally.
        self.id = str(uuid.uuid4())

    def handle(self, data):
        """Render ``data`` with ``print_worker_logs`` and enqueue the text."""
        is_err = data["is_err"]
        entry = ray_client_pb2.LogData()
        # stderr entries use level -2, stdout entries level -1.
        if is_err:
            entry.level = -2
            entry.name = "stderr"
        else:
            entry.level = -1
            entry.name = "stdout"
        with io.StringIO() as buffer:
            print_worker_logs(data, buffer)
            entry.msg = buffer.getvalue()
        self.queue.put(entry)

    def register_global(self):
        """Register ``handle`` with the global stdstream dispatcher."""
        global_worker_stdstream_dispatcher.add_handler(self.id, self.handle)

    def unregister_global(self):
        """Remove this handler from the global stdstream dispatcher."""
        global_worker_stdstream_dispatcher.remove_handler(self.id)
def log_status_change_thread(log_queue, request_iterator):
    """Apply each log-settings request to the "ray" logger.

    Every request replaces the previously installed handler. A request with
    ``enabled`` unset only removes it. When the request stream ends, the
    handler is removed and ``None`` is put on ``log_queue``.

    Args:
        log_queue: Queue that receives log records and the final ``None``.
        request_iterator: Iterator of log-settings requests.
    """
    std_handler = StdStreamHandler(log_queue)
    current_handler = None
    root_logger = logging.getLogger("ray")
    default_level = root_logger.getEffectiveLevel()

    def detach(handler):
        # Undo what installing a handler did; no-op when none is installed.
        if handler is None:
            return
        root_logger.setLevel(default_level)
        root_logger.removeHandler(handler)
        std_handler.unregister_global()

    try:
        for req in request_iterator:
            detach(current_handler)
            if not req.enabled:
                current_handler = None
                continue
            current_handler = LogstreamHandler(log_queue, req.loglevel)
            std_handler.register_global()
            root_logger.addHandler(current_handler)
            root_logger.setLevel(req.loglevel)
    except grpc.RpcError as e:
        logger.debug(f"closing log thread "
                     f"grpc error reading request_iterator: {e}")
    finally:
        detach(current_handler)
        # Sentinel telling the consumer of log_queue to stop.
        log_queue.put(None)
class LogstreamServicer(ray_client_pb2_grpc.RayletLogStreamerServicer):
    """Streams queued log records back to a connected client."""

    def Logstream(self, request_iterator, context):
        """Yield log records until the settings thread signals the end.

        A daemon thread consumes ``request_iterator`` and fills the queue;
        a ``None`` on the queue ends the stream.
        """
        logger.info("New logs connection")
        log_queue = queue.Queue()
        thread = threading.Thread(
            target=log_status_change_thread,
            args=(log_queue, request_iterator),
            daemon=True)
        thread.start()
        try:
            while True:
                record = log_queue.get()
                if record is None:
                    break
                yield record
        except grpc.RpcError as e:
            logger.debug(f"Closing log channel: {e}")
        finally:
            thread.join()
+305 -150
View File
@@ -1,6 +1,14 @@
import logging
from concurrent import futures
import grpc
import base64
from collections import defaultdict
from typing import Any
from typing import Dict
from typing import Set
from typing import Optional
from ray import cloudpickle
import ray
import ray.state
@@ -9,29 +17,34 @@ import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
import time
import inspect
import json
from ray.experimental.client import stash_api_for_tests, _set_server_api
from ray.experimental.client.common import convert_from_arg
from ray.experimental.client.common import encode_exception
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.server.core_ray_api import RayServerAPI
from ray.experimental.client.server.server_pickler import convert_from_arg
from ray.experimental.client.server.server_pickler import dumps_from_server
from ray.experimental.client.server.server_pickler import loads_from_client
from ray.experimental.client.server.dataservicer import DataServicer
from ray.experimental.client.server.logservicer import LogstreamServicer
from ray.experimental.client.server.server_stubs import current_remote
from ray._private.client_mode_hook import disable_client_hook
logger = logging.getLogger(__name__)
class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
def __init__(self, test_mode=False):
self.object_refs = {}
def __init__(self):
self.object_refs: Dict[str, Dict[bytes, ray.ObjectRef]] = defaultdict(
dict)
self.function_refs = {}
self.actor_refs = {}
self.actor_refs: Dict[bytes, ray.ActorHandle] = {}
self.actor_owners: Dict[str, Set[bytes]] = defaultdict(set)
self.registered_actor_classes = {}
self._test_mode = test_mode
self._current_function_stub = None
def ClusterInfo(self, request,
context=None) -> ray_client_pb2.ClusterInfoResponse:
resp = ray_client_pb2.ClusterInfoResponse()
resp.type = request.type
if request.type == ray_client_pb2.ClusterInfoType.CLUSTER_RESOURCES:
resources = ray.cluster_resources()
with disable_client_hook():
resources = ray.cluster_resources()
# Normalize resources into floats
# (the function may return values that are ints)
float_resources = {k: float(v) for k, v in resources.items()}
@@ -40,7 +53,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
table=float_resources))
elif request.type == \
ray_client_pb2.ClusterInfoType.AVAILABLE_RESOURCES:
resources = ray.available_resources()
with disable_client_hook():
resources = ray.available_resources()
# Normalize resources into floats
# (the function may return values that are ints)
float_resources = {k: float(v) for k, v in resources.items()}
@@ -48,7 +62,8 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
ray_client_pb2.ClusterInfoResponse.ResourceTable(
table=float_resources))
else:
resp.json = self._return_debug_cluster_info(request, context)
with disable_client_hook():
resp.json = self._return_debug_cluster_info(request, context)
return resp
def _return_debug_cluster_info(self, request, context=None) -> str:
@@ -61,20 +76,61 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
raise TypeError("Unsupported cluster info type")
return json.dumps(data)
def Terminate(self, request, context=None):
if request.WhichOneof("terminate_type") == "task_object":
def release(self, client_id: str, id: bytes) -> bool:
    """Drop one object or actor reference held for ``client_id``.

    Args:
        client_id: The client that owns the reference.
        id: The id of the object or actor to release.

    Returns:
        True if a matching object or actor was found and released,
        otherwise False.
    """
    owns_object = (client_id in self.object_refs
                   and id in self.object_refs[client_id])
    if owns_object:
        logger.debug(f"Releasing object {id.hex()} for {client_id}")
        del self.object_refs[client_id][id]
        return True
    owns_actor = (client_id in self.actor_owners
                  and id in self.actor_owners[client_id])
    if owns_actor:
        logger.debug(f"Releasing actor {id.hex()} for {client_id}")
        del self.actor_refs[id]
        self.actor_owners[client_id].remove(id)
        return True
    return False
def release_all(self, client_id):
    """Release every object and then every actor held for ``client_id``."""
    for releaser in (self._release_objects, self._release_actors):
        releaser(client_id)
def _release_objects(self, client_id):
    """Drop all object references tracked for ``client_id``."""
    if client_id in self.object_refs:
        count = len(self.object_refs[client_id])
        del self.object_refs[client_id]
        logger.debug(f"Released all {count} objects for client {client_id}")
    else:
        logger.debug(f"Releasing client with no references: {client_id}")
def _release_actors(self, client_id):
    """Drop all actor handles owned by ``client_id``.

    Args:
        client_id: The client whose actors should be released.
    """
    if client_id not in self.actor_owners:
        logger.debug(f"Releasing client with no actors: {client_id}")
        # Nothing to release. Returning here avoids creating and deleting
        # an owner entry and logging a second "released" message.
        return
    count = 0
    for id_bytes in self.actor_owners[client_id]:
        count += 1
        del self.actor_refs[id_bytes]
    del self.actor_owners[client_id]
    logger.debug(f"Released all {count} actors for client: {client_id}")
def Terminate(self, req, context=None):
if req.WhichOneof("terminate_type") == "task_object":
try:
object_ref = cloudpickle.loads(request.task_object.handle)
ray.cancel(
object_ref,
force=request.task_object.force,
recursive=request.task_object.recursive)
object_ref = \
self.object_refs[req.client_id][req.task_object.id]
with disable_client_hook():
ray.cancel(
object_ref,
force=req.task_object.force,
recursive=req.task_object.recursive)
except Exception as e:
return_exception_in_context(e, context)
elif request.WhichOneof("terminate_type") == "actor":
elif req.WhichOneof("terminate_type") == "actor":
try:
actor_ref = cloudpickle.loads(request.actor.handle)
ray.kill(actor_ref, no_restart=request.actor.no_restart)
actor_ref = self.actor_refs[req.actor.id]
with disable_client_hook():
ray.kill(actor_ref, no_restart=req.actor.no_restart)
except Exception as e:
return_exception_in_context(e, context)
else:
@@ -84,166 +140,221 @@ class RayletServicer(ray_client_pb2_grpc.RayletDriverServicer):
return ray_client_pb2.TerminateResponse(ok=True)
def GetObject(self, request, context=None):
request_ref = cloudpickle.loads(request.handle)
if request_ref.binary() not in self.object_refs:
return self._get_object(request, "", context)
def _get_object(self, request, client_id: str, context=None):
if request.id not in self.object_refs[client_id]:
return ray_client_pb2.GetResponse(valid=False)
objectref = self.object_refs[request_ref.binary()]
logger.info("get: %s" % objectref)
objectref = self.object_refs[client_id][request.id]
logger.debug("get: %s" % objectref)
try:
item = ray.get(objectref, timeout=request.timeout)
with disable_client_hook():
item = ray.get(objectref, timeout=request.timeout)
except Exception as e:
return_exception_in_context(e, context)
item_ser = cloudpickle.dumps(item)
return ray_client_pb2.GetResponse(
valid=False, error=cloudpickle.dumps(e))
item_ser = dumps_from_server(item, client_id, self)
return ray_client_pb2.GetResponse(valid=True, data=item_ser)
def PutObject(self, request, context=None) -> ray_client_pb2.PutResponse:
obj = cloudpickle.loads(request.data)
objectref = self._put_and_retain_obj(obj)
pickled_ref = cloudpickle.dumps(objectref)
return ray_client_pb2.PutResponse(
ref=make_remote_ref(objectref.binary(), pickled_ref))
def PutObject(self, request: ray_client_pb2.PutRequest,
context=None) -> ray_client_pb2.PutResponse:
"""gRPC entrypoint for unary PutObject
"""
return self._put_object(request, "", context)
def _put_and_retain_obj(self, obj) -> ray.ObjectRef:
objectref = ray.put(obj)
self.object_refs[objectref.binary()] = objectref
logger.info("put: %s" % objectref)
return objectref
def _put_object(self,
                request: ray_client_pb2.PutRequest,
                client_id: str,
                context=None):
    """Put an object in the cluster with ray.put() via gRPC.

    Args:
        request: PutRequest with pickled data.
        client_id: The client who owns this data, for tracking when to
          delete this reference.
        context: gRPC context.
    """
    obj = loads_from_client(request.data, self)
    with disable_client_hook():
        objectref = ray.put(obj)
    # Keep the ref under its owner so it can be released later.
    owned_refs = self.object_refs[client_id]
    owned_refs[objectref.binary()] = objectref
    logger.debug("put: %s" % objectref)
    return ray_client_pb2.PutResponse(id=objectref.binary())
def WaitObject(self, request, context=None) -> ray_client_pb2.WaitResponse:
object_refs = [cloudpickle.loads(o) for o in request.object_handles]
object_refs = []
for id in request.object_ids:
if id not in self.object_refs[request.client_id]:
raise Exception(
"Asking for a ref not associated with this client: %s" %
str(id))
object_refs.append(self.object_refs[request.client_id][id])
num_returns = request.num_returns
timeout = request.timeout
object_refs_ids = []
for object_ref in object_refs:
if object_ref.binary() not in self.object_refs:
return ray_client_pb2.WaitResponse(valid=False)
object_refs_ids.append(self.object_refs[object_ref.binary()])
try:
ready_object_refs, remaining_object_refs = ray.wait(
object_refs_ids,
num_returns=num_returns,
timeout=timeout if timeout != -1 else None)
except Exception:
with disable_client_hook():
ready_object_refs, remaining_object_refs = ray.wait(
object_refs,
num_returns=num_returns,
timeout=timeout if timeout != -1 else None,
)
except Exception as e:
# TODO(ameer): improve exception messages.
logger.error(f"Exception {e}")
return ray_client_pb2.WaitResponse(valid=False)
logger.info("wait: %s %s" % (str(ready_object_refs),
str(remaining_object_refs)))
logger.debug("wait: %s %s" % (str(ready_object_refs),
str(remaining_object_refs)))
ready_object_ids = [
make_remote_ref(
id=ready_object_ref.binary(),
handle=cloudpickle.dumps(ready_object_ref),
) for ready_object_ref in ready_object_refs
ready_object_ref.binary() for ready_object_ref in ready_object_refs
]
remaining_object_ids = [
make_remote_ref(
id=remaining_object_ref.binary(),
handle=cloudpickle.dumps(remaining_object_ref),
) for remaining_object_ref in remaining_object_refs
remaining_object_ref.binary()
for remaining_object_ref in remaining_object_refs
]
return ray_client_pb2.WaitResponse(
valid=True,
ready_object_ids=ready_object_ids,
remaining_object_ids=remaining_object_ids)
def Schedule(self, task, context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
logger.info("schedule: %s %s" %
(task.name,
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type)))
if task.type == ray_client_pb2.ClientTask.FUNCTION:
return self._schedule_function(task, context, prepared_args)
elif task.type == ray_client_pb2.ClientTask.ACTOR:
return self._schedule_actor(task, context, prepared_args)
elif task.type == ray_client_pb2.ClientTask.METHOD:
return self._schedule_method(task, context, prepared_args)
else:
raise NotImplementedError(
"Unimplemented Schedule task type: %s" %
ray_client_pb2.ClientTask.RemoteExecType.Name(task.type))
def Schedule(self, task, context=None) -> ray_client_pb2.ClientTaskTicket:
logger.debug(
"schedule: %s %s" % (task.name,
ray_client_pb2.ClientTask.RemoteExecType.Name(
task.type)))
try:
with disable_client_hook():
if task.type == ray_client_pb2.ClientTask.FUNCTION:
result = self._schedule_function(task, context)
elif task.type == ray_client_pb2.ClientTask.ACTOR:
result = self._schedule_actor(task, context)
elif task.type == ray_client_pb2.ClientTask.METHOD:
result = self._schedule_method(task, context)
elif task.type == ray_client_pb2.ClientTask.NAMED_ACTOR:
result = self._schedule_named_actor(task, context)
else:
raise NotImplementedError(
"Unimplemented Schedule task type: %s" %
ray_client_pb2.ClientTask.RemoteExecType.Name(
task.type))
result.valid = True
return result
except Exception as e:
logger.error(f"Caught schedule exception {e}")
raise e
return ray_client_pb2.ClientTaskTicket(
valid=False, error=cloudpickle.dumps(e))
def _schedule_method(
self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
def _schedule_method(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
actor_handle = self.actor_refs.get(task.payload_id)
if actor_handle is None:
raise Exception(
"Can't run an actor the server doesn't have a handle for")
arglist = _convert_args(task.args, prepared_args)
with stash_api_for_tests(self._test_mode):
output = getattr(actor_handle, task.name).remote(*arglist)
self.object_refs[output.binary()] = output
pickled_ref = cloudpickle.dumps(output)
return ray_client_pb2.ClientTaskTicket(
return_ref=make_remote_ref(output.binary(), pickled_ref))
arglist, kwargs = self._convert_args(task.args, task.kwargs)
method = getattr(actor_handle, task.name)
opts = decode_options(task.options)
if opts is not None:
method = method.options(**opts)
output = method.remote(*arglist, **kwargs)
ids = self.unify_and_track_outputs(output, task.client_id)
return ray_client_pb2.ClientTaskTicket(return_ids=ids)
def _schedule_actor(self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
with stash_api_for_tests(self._test_mode):
payload_ref = cloudpickle.loads(task.payload_id)
if payload_ref.binary() not in self.registered_actor_classes:
actor_class_ref = self.object_refs[payload_ref.binary()]
def _schedule_actor(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
remote_class = self.lookup_or_register_actor(
task.payload_id, task.client_id,
decode_options(task.baseline_options))
arglist, kwargs = self._convert_args(task.args, task.kwargs)
opts = decode_options(task.options)
if opts is not None:
remote_class = remote_class.options(**opts)
with current_remote(remote_class):
actor = remote_class.remote(*arglist, **kwargs)
self.actor_refs[actor._actor_id.binary()] = actor
self.actor_owners[task.client_id].add(actor._actor_id.binary())
return ray_client_pb2.ClientTaskTicket(
return_ids=[actor._actor_id.binary()])
def _schedule_function(self, task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
remote_func = self.lookup_or_register_func(
task.payload_id, task.client_id,
decode_options(task.baseline_options))
arglist, kwargs = self._convert_args(task.args, task.kwargs)
opts = decode_options(task.options)
if opts is not None:
remote_func = remote_func.options(**opts)
with current_remote(remote_func):
output = remote_func.remote(*arglist, **kwargs)
ids = self.unify_and_track_outputs(output, task.client_id)
return ray_client_pb2.ClientTaskTicket(return_ids=ids)
def _schedule_named_actor(self,
task: ray_client_pb2.ClientTask,
context=None) -> ray_client_pb2.ClientTaskTicket:
assert len(task.payload_id) == 0
actor = ray.get_actor(task.name)
self.actor_refs[actor._actor_id.binary()] = actor
self.actor_owners[task.client_id].add(actor._actor_id.binary())
return ray_client_pb2.ClientTaskTicket(
return_ids=[actor._actor_id.binary()])
def _convert_args(self, arg_list, kwarg_map):
argout = []
for arg in arg_list:
t = convert_from_arg(arg, self)
argout.append(t)
kwargout = {}
for k in kwarg_map:
kwargout[k] = convert_from_arg(kwarg_map[k], self)
return argout, kwargout
def lookup_or_register_func(
self, id: bytes, client_id: str,
options: Optional[Dict]) -> ray.remote_function.RemoteFunction:
with disable_client_hook():
if id not in self.function_refs:
funcref = self.object_refs[client_id][id]
func = ray.get(funcref)
if not inspect.isfunction(func):
raise Exception("Attempting to register function that "
"isn't a function.")
if options is None or len(options) == 0:
self.function_refs[id] = ray.remote(func)
else:
self.function_refs[id] = ray.remote(**options)(func)
return self.function_refs[id]
def lookup_or_register_actor(self, id: bytes, client_id: str,
options: Optional[Dict]):
with disable_client_hook():
if id not in self.registered_actor_classes:
actor_class_ref = self.object_refs[client_id][id]
actor_class = ray.get(actor_class_ref)
if not inspect.isclass(actor_class):
raise Exception("Attempting to schedule actor that "
"isn't a class.")
reg_class = ray.remote(actor_class)
self.registered_actor_classes[payload_ref.binary()] = reg_class
remote_class = self.registered_actor_classes[payload_ref.binary()]
arglist = _convert_args(task.args, prepared_args)
actor = remote_class.remote(*arglist)
actorhandle = cloudpickle.dumps(actor)
self.actor_refs[actorhandle] = actor
return ray_client_pb2.ClientTaskTicket(
return_ref=make_remote_ref(actor._actor_id.binary(), actorhandle))
if options is None or len(options) == 0:
reg_class = ray.remote(actor_class)
else:
reg_class = ray.remote(**options)(actor_class)
self.registered_actor_classes[id] = reg_class
def _schedule_function(
self,
task: ray_client_pb2.ClientTask,
context=None,
prepared_args=None) -> ray_client_pb2.ClientTaskTicket:
payload_ref = cloudpickle.loads(task.payload_id)
if payload_ref.binary() not in self.function_refs:
funcref = self.object_refs[payload_ref.binary()]
func = ray.get(funcref)
if not inspect.isfunction(func):
raise Exception("Attempting to schedule function that "
"isn't a function.")
self.function_refs[payload_ref.binary()] = ray.remote(func)
remote_func = self.function_refs[payload_ref.binary()]
arglist = _convert_args(task.args, prepared_args)
# Prepare call if we're in a test
with stash_api_for_tests(self._test_mode):
output = remote_func.remote(*arglist)
if output.binary() in self.object_refs:
raise Exception("already found it")
self.object_refs[output.binary()] = output
pickled_output = cloudpickle.dumps(output)
return ray_client_pb2.ClientTaskTicket(
return_ref=make_remote_ref(output.binary(), pickled_output))
return self.registered_actor_classes[id]
def _convert_args(arg_list, prepared_args=None):
if prepared_args is not None:
return prepared_args
out = []
for arg in arg_list:
t = convert_from_arg(arg)
if isinstance(t, ClientObjectRef):
out.append(t._unpack_ref())
def unify_and_track_outputs(self, output, client_id):
if output is None:
outputs = []
elif isinstance(output, list):
outputs = output
else:
out.append(t)
return out
def make_remote_ref(id: bytes, handle: bytes) -> ray_client_pb2.RemoteRef:
return ray_client_pb2.RemoteRef(
id=id,
handle=handle,
)
outputs = [output]
for out in outputs:
if out.binary() in self.object_refs[client_id]:
logger.warning(f"Already saw object_ref {out}")
self.object_refs[client_id][out.binary()] = out
return [out.binary() for out in outputs]
def return_exception_in_context(err, context):
@@ -252,17 +363,61 @@ def return_exception_in_context(err, context):
context.set_code(grpc.StatusCode.INTERNAL)
def serve(connection_str, test_mode=False):
def encode_exception(exception) -> str:
data = cloudpickle.dumps(exception)
return base64.standard_b64encode(data).decode()
def decode_options(
        options: ray_client_pb2.TaskOptions) -> Optional[Dict[str, Any]]:
    """Decode the JSON-encoded options carried by a task message.

    Args:
        options: message whose ``json_options`` field holds a JSON string.

    Returns:
        ``None`` when ``json_options`` is the empty string, otherwise the
        decoded dict.

    Raises:
        AssertionError: if the decoded JSON value is not a dict.
    """
    raw = options.json_options
    if raw == "":
        return None
    decoded = json.loads(raw)
    assert isinstance(decoded, dict)
    return decoded
_current_servicer: Optional[RayletServicer] = None
# Used by tests to peek inside the servicer
def _get_current_servicer():
global _current_servicer
return _current_servicer
def serve(connection_str):
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
task_servicer = RayletServicer(test_mode=test_mode)
_set_server_api(RayServerAPI(task_servicer))
task_servicer = RayletServicer()
data_servicer = DataServicer(task_servicer)
logs_servicer = LogstreamServicer()
global _current_servicer
_current_servicer = task_servicer
ray_client_pb2_grpc.add_RayletDriverServicer_to_server(
task_servicer, server)
ray_client_pb2_grpc.add_RayletDataStreamerServicer_to_server(
data_servicer, server)
ray_client_pb2_grpc.add_RayletLogStreamerServicer_to_server(
logs_servicer, server)
server.add_insecure_port(connection_str)
server.start()
return server
def init_and_serve(connection_str, *args, **kwargs):
with disable_client_hook():
# Disable client mode inside the worker's environment
info = ray.init(*args, **kwargs)
server = serve(connection_str)
return (server, info)
def shutdown_with_server(server, _exiting_interpreter=False):
server.stop(1)
with disable_client_hook():
ray.shutdown(_exiting_interpreter)
if __name__ == "__main__":
logging.basicConfig(level="INFO")
# TODO(barakmich): Perhaps wrap ray init
@@ -0,0 +1,135 @@
"""Implements the server side of the client/server pickling protocol.
These picklers are aware of the server internals and can find the
references held for the client within the server.
More discussion about the client/server pickling protocol can be found in:
ray/experimental/client/client_pickler.py
ServerPickler dumps ray objects from the server into the appropriate stubs.
ClientUnpickler loads stubs from the client and finds their associated handle
in the server instance.
"""
import cloudpickle
import io
import sys
import ray
from typing import Any
from typing import TYPE_CHECKING
from ray._private.client_mode_hook import disable_client_hook
from ray.experimental.client.client_pickler import PickleStub
from ray.experimental.client.server.server_stubs import (
ServerSelfReferenceSentinel)
if TYPE_CHECKING:
from ray.experimental.client.server.server import RayletServicer
import ray.core.generated.ray_client_pb2 as ray_client_pb2
if sys.version_info < (3, 8):
try:
import pickle5 as pickle # noqa: F401
except ImportError:
import pickle # noqa: F401
else:
import pickle # noqa: F401
class ServerPickler(cloudpickle.CloudPickler):
    """Pickler the server uses for values that are sent back to a client.

    ``ray.ObjectRef`` and ``ray.actor.ActorHandle`` instances are not
    pickled by value. Each is replaced by a ``PickleStub`` persistent id,
    and the real object is recorded in the server's reference tables.
    """

    def __init__(self, client_id: str, server: "RayletServicer", *args,
                 **kwargs):
        """Create a pickler bound to one client.

        Args:
            client_id: id of the client the pickled data is destined for.
            server: servicer whose ``object_refs``, ``actor_refs`` and
                ``actor_owners`` tables are consulted and updated.
            *args: forwarded to ``cloudpickle.CloudPickler``.
            **kwargs: forwarded to ``cloudpickle.CloudPickler``.
        """
        super().__init__(*args, **kwargs)
        self.client_id = client_id
        self.server = server

    def persistent_id(self, obj):
        """Return a ``PickleStub`` for Ray references, ``None`` otherwise.

        Returning ``None`` tells the pickler to serialize ``obj`` normally.
        """
        if isinstance(obj, ray.ObjectRef):
            obj_id = obj.binary()
            if obj_id not in self.server.object_refs[self.client_id]:
                # We're passing back a reference, probably inside a reference.
                # Let's hold onto it.
                self.server.object_refs[self.client_id][obj_id] = obj
            return PickleStub(
                type="Object",
                client_id=self.client_id,
                ref_id=obj_id,
                name=None,
                baseline_options=None,
            )
        elif isinstance(obj, ray.actor.ActorHandle):
            actor_id = obj._actor_id.binary()
            # The reference tables live on the servicer, not on the pickler;
            # the pickler itself has no actor_refs / actor_owners attributes.
            if actor_id not in self.server.actor_refs:
                # We're passing back a handle, probably inside a reference.
                self.server.actor_refs[actor_id] = obj
            # Adding is a no-op when this client is already recorded as an
            # owner of the actor.
            self.server.actor_owners[self.client_id].add(actor_id)
            return PickleStub(
                type="Actor",
                client_id=self.client_id,
                ref_id=actor_id,
                name=None,
                baseline_options=None,
            )
        return None
class ClientUnpickler(pickle.Unpickler):
    """Unpickler that resolves client-sent ``PickleStub`` ids on the server.

    Each persistent id is mapped to the object the server holds for it.
    """

    def __init__(self, server, *args, **kwargs):
        """Store ``server``; remaining arguments go to ``pickle.Unpickler``."""
        super().__init__(*args, **kwargs)
        self.server = server

    def persistent_load(self, pid):
        """Resolve one persistent id according to its ``type`` field.

        Raises:
            AssertionError: if ``pid`` is not a ``PickleStub``.
            NotImplementedError: if ``pid.type`` is not a handled kind.
        """
        assert isinstance(pid, PickleStub)
        kind = pid.type
        if kind == "Ray":
            return ray
        if kind == "Object":
            return self.server.object_refs[pid.client_id][pid.ref_id]
        if kind == "Actor":
            return self.server.actor_refs[pid.ref_id]
        if kind in ("RemoteFuncSelfReference", "RemoteActorSelfReference"):
            return ServerSelfReferenceSentinel()
        if kind == "RemoteFunc":
            return self.server.lookup_or_register_func(
                pid.ref_id, pid.client_id, pid.baseline_options)
        if kind == "RemoteActor":
            return self.server.lookup_or_register_actor(
                pid.ref_id, pid.client_id, pid.baseline_options)
        if kind == "RemoteMethod":
            handle = self.server.actor_refs[pid.ref_id]
            return getattr(handle, pid.name)
        raise NotImplementedError("Uncovered client data type")
def dumps_from_server(obj: Any,
                      client_id: str,
                      server_instance: "RayletServicer",
                      protocol=None) -> bytes:
    """Pickle ``obj`` with a ``ServerPickler`` and return the bytes.

    Args:
        obj: value to serialize.
        client_id: id of the client the data is destined for.
        server_instance: servicer handed to the ``ServerPickler``.
        protocol: pickle protocol passed through to the pickler.

    Returns:
        The pickled representation of ``obj``.
    """
    buffer = io.BytesIO()
    try:
        pickler = ServerPickler(
            client_id, server_instance, buffer, protocol=protocol)
        pickler.dump(obj)
        return buffer.getvalue()
    finally:
        buffer.close()
def loads_from_client(data: bytes,
                      server_instance: "RayletServicer",
                      *,
                      fix_imports=True,
                      encoding="ASCII",
                      errors="strict") -> Any:
    """Unpickle bytes sent by a client, resolving stubs against the server.

    Loading runs inside ``disable_client_hook()``.

    Args:
        data: pickled payload received from the client.
        server_instance: servicer used by ``ClientUnpickler`` to resolve
            persistent ids.
        fix_imports: forwarded to the unpickler.
        encoding: forwarded to the unpickler.
        errors: forwarded to the unpickler.

    Returns:
        The reconstructed object.

    Raises:
        TypeError: if ``data`` is a ``str`` rather than bytes.
    """
    # NOTE(security): this unpickles bytes supplied by the client. Unpickling
    # can execute arbitrary code, so the client must be trusted.
    with disable_client_hook():
        if isinstance(data, str):
            raise TypeError("Can't load pickle from unicode string")
        file = io.BytesIO(data)
        # `errors` was previously accepted but silently dropped; forward it
        # along with the other decoding options.
        return ClientUnpickler(
            server_instance,
            file,
            fix_imports=fix_imports,
            encoding=encoding,
            errors=errors).load()
def convert_from_arg(pb: "ray_client_pb2.Arg",
                     server: "RayletServicer") -> Any:
    """Decode one argument message by unpickling its ``data`` field.

    Args:
        pb: argument message; only its ``data`` attribute is read.
        server: servicer passed on to ``loads_from_client``.

    Returns:
        Whatever ``loads_from_client`` reconstructs from ``pb.data``.
    """
    payload = pb.data
    return loads_from_client(payload, server)
@@ -0,0 +1,29 @@
from contextlib import contextmanager
_current_remote_obj = None
@contextmanager
def current_remote(r):
    """Temporarily set the module-level current remote object to ``r``.

    The previous value is restored on exit, including when the body raises,
    so nested uses unwind correctly.
    """
    global _current_remote_obj
    previous, _current_remote_obj = _current_remote_obj, r
    try:
        yield
    finally:
        _current_remote_obj = previous
class ServerSelfReferenceSentinel:
    """Placeholder whose pickled form depends on the current remote object.

    While a remote object is set (see ``_current_remote_obj``), pickling the
    sentinel reduces to that object via ``identity``; otherwise it reduces
    to a fresh sentinel.
    """

    def __init__(self):
        pass

    def __reduce__(self):
        """Return the pickle reduction tuple for this sentinel."""
        if _current_remote_obj is not None:
            return (identity, (_current_remote_obj, ))
        return (ServerSelfReferenceSentinel, ())
def identity(x):
    """Return ``x`` unchanged.

    Used as the reconstruction callable in
    ``ServerSelfReferenceSentinel.__reduce__``.
    """
    return x
+124 -63
View File
@@ -2,27 +2,30 @@
It implements the Ray API functions that are forwarded through grpc calls
to the server.
"""
import inspect
import base64
import json
import logging
import uuid
from collections import defaultdict
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Optional
import ray.cloudpickle as cloudpickle
from ray.util.inspect import is_cython
import grpc
from ray.exceptions import TaskCancelledError
import ray.cloudpickle as cloudpickle
import ray.core.generated.ray_client_pb2 as ray_client_pb2
import ray.core.generated.ray_client_pb2_grpc as ray_client_pb2_grpc
from ray.experimental.client.common import convert_to_arg
from ray.experimental.client.common import decode_exception
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.common import ClientActorClass
from ray.experimental.client.client_pickler import convert_to_arg
from ray.experimental.client.client_pickler import dumps_from_client
from ray.experimental.client.client_pickler import loads_from_server
from ray.experimental.client.common import ClientActorHandle
from ray.experimental.client.common import ClientRemoteFunc
from ray.experimental.client.common import ClientActorRef
from ray.experimental.client.common import ClientObjectRef
from ray.experimental.client.dataclient import DataClient
from ray.experimental.client.logsclient import LogstreamClient
logger = logging.getLogger(__name__)
@@ -31,34 +34,37 @@ class Worker:
def __init__(self,
conn_str: str = "",
secure: bool = False,
metadata: List[Tuple[str, str]] = None,
stub=None):
metadata: List[Tuple[str, str]] = None):
"""Initializes the worker side grpc client.
Args:
stub: custom grpc stub.
secure: whether to use SSL secure channel or not.
metadata: additional metadata passed in the grpc request headers.
"""
self.metadata = metadata
self.channel = None
if stub is None:
if secure:
credentials = grpc.ssl_channel_credentials()
self.channel = grpc.secure_channel(conn_str, credentials)
else:
self.channel = grpc.insecure_channel(conn_str)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
self._client_id = make_client_id()
if secure:
credentials = grpc.ssl_channel_credentials()
self.channel = grpc.secure_channel(conn_str, credentials)
else:
self.server = stub
self.channel = grpc.insecure_channel(conn_str)
self.server = ray_client_pb2_grpc.RayletDriverStub(self.channel)
self.data_client = DataClient(self.channel, self._client_id)
self.reference_count: Dict[bytes, int] = defaultdict(int)
self.log_client = LogstreamClient(self.channel)
self.log_client.set_logstream_level(logging.INFO)
self.closed = False
def get(self, vals, *, timeout: Optional[float] = None) -> Any:
to_get = []
single = False
if isinstance(vals, list):
to_get = [x.handle for x in vals]
to_get = vals
elif isinstance(vals, ClientObjectRef):
to_get = [vals.handle]
to_get = [vals]
single = True
else:
raise Exception("Can't get something that's not a "
@@ -70,15 +76,17 @@ class Worker:
out = out[0]
return out
def _get(self, handle: bytes, timeout: float):
req = ray_client_pb2.GetRequest(handle=handle, timeout=timeout)
def _get(self, ref: ClientObjectRef, timeout: float):
req = ray_client_pb2.GetRequest(id=ref.id, timeout=timeout)
try:
data = self.server.GetObject(req, metadata=self.metadata)
data = self.data_client.GetObject(req)
except grpc.RpcError as e:
raise decode_exception(e.details())
raise e.details()
if not data.valid:
raise TaskCancelledError(handle)
return cloudpickle.loads(data.data)
err = cloudpickle.loads(data.error)
logger.error(err)
raise err
return loads_from_server(data.data)
def put(self, vals):
to_put = []
@@ -95,26 +103,37 @@ class Worker:
return out
def _put(self, val):
data = cloudpickle.dumps(val)
if isinstance(val, ClientObjectRef):
raise TypeError(
"Calling 'put' on an ObjectRef is not allowed "
"(similarly, returning an ObjectRef from a remote "
"function is not allowed). If you really want to "
"do this, you can wrap the ObjectRef in a list and "
"call 'put' on it (or return it).")
data = dumps_from_client(val, self._client_id)
req = ray_client_pb2.PutRequest(data=data)
resp = self.server.PutObject(req, metadata=self.metadata)
return ClientObjectRef.from_remote_ref(resp.ref)
resp = self.data_client.PutObject(req)
return ClientObjectRef(resp.id)
def wait(self,
object_refs: List[ClientObjectRef],
*,
num_returns: int = 1,
timeout: float = None
timeout: float = None,
fetch_local: bool = True
) -> Tuple[List[ClientObjectRef], List[ClientObjectRef]]:
assert isinstance(object_refs, list)
if not isinstance(object_refs, list):
raise TypeError("wait() expected a list of ClientObjectRef, "
f"got {type(object_refs)}")
for ref in object_refs:
assert isinstance(ref, ClientObjectRef)
if not isinstance(ref, ClientObjectRef):
raise TypeError("wait() expected a list of ClientObjectRef, "
f"got list containing {type(ref)}")
data = {
"object_handles": [
object_ref.handle for object_ref in object_refs
],
"object_ids": [object_ref.id for object_ref in object_refs],
"num_returns": num_returns,
"timeout": timeout if timeout else -1
"timeout": timeout if timeout else -1,
"client_id": self._client_id,
}
req = ray_client_pb2.WaitRequest(**data)
resp = self.server.WaitObject(req, metadata=self.metadata)
@@ -122,41 +141,69 @@ class Worker:
# TODO(ameer): improve error/exceptions messages.
raise Exception("Client Wait request failed. Reference invalid?")
client_ready_object_ids = [
ClientObjectRef.from_remote_ref(ref)
for ref in resp.ready_object_ids
ClientObjectRef(ref) for ref in resp.ready_object_ids
]
client_remaining_object_ids = [
ClientObjectRef.from_remote_ref(ref)
for ref in resp.remaining_object_ids
ClientObjectRef(ref) for ref in resp.remaining_object_ids
]
return (client_ready_object_ids, client_remaining_object_ids)
def remote(self, function_or_class, *args, **kwargs):
# TODO(barakmich): Arguments to ray.remote
# get captured here.
if (inspect.isfunction(function_or_class)
or is_cython(function_or_class)):
return ClientRemoteFunc(function_or_class)
elif inspect.isclass(function_or_class):
return ClientActorClass(function_or_class)
else:
raise TypeError("The @ray.remote decorator must be applied to "
"either a function or to a class.")
def call_remote(self, instance, *args, **kwargs):
def call_remote(self, instance, *args, **kwargs) -> List[bytes]:
task = instance._prepare_client_task()
for arg in args:
pb_arg = convert_to_arg(arg)
pb_arg = convert_to_arg(arg, self._client_id)
task.args.append(pb_arg)
logging.debug("Scheduling %s" % task)
ticket = self.server.Schedule(task, metadata=self.metadata)
return ClientObjectRef.from_remote_ref(ticket.return_ref)
for k, v in kwargs.items():
task.kwargs[k].CopyFrom(convert_to_arg(v, self._client_id))
return self._call_schedule_for_task(task)
def _call_schedule_for_task(
        self, task: ray_client_pb2.ClientTask) -> List[bytes]:
    """Send ``task`` to the server's ``Schedule`` RPC and return its ids.

    The task is stamped with this worker's client id before being sent.

    Args:
        task: the prepared task message to schedule.

    Returns:
        The ``return_ids`` of the ticket issued by the server.

    Raises:
        Exception: the exception decoded from the RPC error details when the
            call fails with ``grpc.RpcError``, or the exception unpickled
            from ``ticket.error`` when the ticket is marked invalid.
    """
    logger.debug("Scheduling %s" % task)
    task.client_id = self._client_id
    try:
        ticket = self.server.Schedule(task, metadata=self.metadata)
    except grpc.RpcError as e:
        # details is a method and must be called: decode_exception
        # base64-decodes its argument, so passing the bound method itself
        # would fail instead of surfacing the server-side exception.
        raise decode_exception(e.details())
    if not ticket.valid:
        raise cloudpickle.loads(ticket.error)
    return ticket.return_ids
def call_release(self, id: bytes) -> None:
if self.closed:
return
self.reference_count[id] -= 1
if self.reference_count[id] == 0:
self._release_server(id)
del self.reference_count[id]
def _release_server(self, id: bytes) -> None:
if self.data_client is not None:
logger.debug(f"Releasing {id}")
self.data_client.ReleaseObject(
ray_client_pb2.ReleaseRequest(ids=[id]))
def call_retain(self, id: bytes) -> None:
logger.debug(f"Retaining {id.hex()}")
self.reference_count[id] += 1
def close(self):
self.server = None
self.log_client.close()
self.data_client.close()
if self.channel:
self.channel.close()
self.channel = None
self.server = None
self.closed = True
def get_actor(self, name: str) -> ClientActorHandle:
task = ray_client_pb2.ClientTask()
task.type = ray_client_pb2.ClientTask.NAMED_ACTOR
task.name = name
ids = self._call_schedule_for_task(task)
assert len(ids) == 1
return ClientActorHandle(ClientActorRef(ids[0]))
def terminate_actor(self, actor: ClientActorHandle,
no_restart: bool) -> None:
@@ -164,10 +211,11 @@ class Worker:
raise ValueError("ray.kill() only supported for actors. "
"Got: {}.".format(type(actor)))
term_actor = ray_client_pb2.TerminateRequest.ActorTerminate()
term_actor.handle = actor.actor_ref.handle
term_actor.id = actor.actor_ref.id
term_actor.no_restart = no_restart
try:
term = ray_client_pb2.TerminateRequest(actor=term_actor)
term.client_id = self._client_id
self.server.Terminate(term)
except grpc.RpcError as e:
raise decode_exception(e.details())
@@ -179,11 +227,12 @@ class Worker:
"ray.cancel() only supported for non-actor object refs. "
f"Got: {type(obj)}.")
term_object = ray_client_pb2.TerminateRequest.TaskObjectTerminate()
term_object.handle = obj.handle
term_object.id = obj.id
term_object.force = force
term_object.recursive = recursive
try:
term = ray_client_pb2.TerminateRequest(task_object=term_object)
term.client_id = self._client_id
self.server.Terminate(term)
except grpc.RpcError as e:
raise decode_exception(e.details())
@@ -193,7 +242,9 @@ class Worker:
req.type = type
resp = self.server.ClusterInfo(req)
if resp.WhichOneof("response_type") == "resource_table":
return resp.resource_table.table
# translate from a proto map to a python dict
output_dict = {k: v for k, v in resp.resource_table.table.items()}
return output_dict
return json.loads(resp.json)
def is_initialized(self) -> bool:
@@ -201,3 +252,13 @@ class Worker:
return self.get_cluster_info(
ray_client_pb2.ClusterInfoType.IS_INITIALIZED)
return False
def make_client_id() -> str:
    """Return a new random client id: the hex form of a fresh UUID4."""
    return uuid.uuid4().hex
def decode_exception(data) -> Exception:
data = base64.standard_b64decode(data)
return loads_from_server(data)
@@ -1,18 +0,0 @@
import ray
def force_spill_objects(object_refs):
"""Force spilling objects to external storage.
Args:
object_refs: Object refs of the objects to be
spilled.
"""
core_worker = ray.worker.global_worker.core_worker
# Make sure that the values are object refs.
for object_ref in object_refs:
if not isinstance(object_ref, ray.ObjectRef):
raise TypeError(
f"Attempting to call `force_spill_objects` on the "
f"value {object_ref}, which is not an ray.ObjectRef.")
return core_worker.force_spill_objects(object_refs)
+12 -3
View File
@@ -157,12 +157,15 @@ class ExternalStorage(metaclass=abc.ABCMeta):
@abc.abstractmethod
def restore_spilled_objects(self, object_refs: List[ObjectRef],
url_with_offset_list: List[str]):
url_with_offset_list: List[str]) -> int:
"""Restore objects from the external storage.
Args:
object_refs: List of object IDs (note that it is not ref).
url_with_offset_list: List of url_with_offset.
Returns:
The total number of bytes restored.
"""
@abc.abstractmethod
@@ -215,6 +218,7 @@ class FileSystemStorage(ExternalStorage):
def restore_spilled_objects(self, object_refs: List[ObjectRef],
url_with_offset_list: List[str]):
total = 0
for i in range(len(object_refs)):
object_ref = object_refs[i]
url_with_offset = url_with_offset_list[i].decode()
@@ -228,9 +232,11 @@ class FileSystemStorage(ExternalStorage):
metadata_len = int.from_bytes(f.read(8), byteorder="little")
buf_len = int.from_bytes(f.read(8), byteorder="little")
self._size_check(metadata_len, buf_len, parsed_result.size)
total += buf_len
metadata = f.read(metadata_len)
# read remaining data to our buffer
self._put_object_to_store(metadata, buf_len, f, object_ref)
return total
def delete_spilled_objects(self, urls: List[str]):
for url in urls:
@@ -297,6 +303,7 @@ class ExternalStorageSmartOpenImpl(ExternalStorage):
def restore_spilled_objects(self, object_refs: List[ObjectRef],
url_with_offset_list: List[str]):
from smart_open import open
total = 0
for i in range(len(object_refs)):
object_ref = object_refs[i]
url_with_offset = url_with_offset_list[i].decode()
@@ -315,9 +322,11 @@ class ExternalStorageSmartOpenImpl(ExternalStorage):
metadata_len = int.from_bytes(f.read(8), byteorder="little")
buf_len = int.from_bytes(f.read(8), byteorder="little")
self._size_check(metadata_len, buf_len, parsed_result.size)
total += buf_len
metadata = f.read(metadata_len)
# read remaining data to our buffer
self._put_object_to_store(metadata, buf_len, f, object_ref)
return total
def delete_spilled_objects(self, urls: List[str]):
pass
@@ -367,8 +376,8 @@ def restore_spilled_objects(object_refs: List[ObjectRef],
object_refs: List of object IDs (note that it is not ref).
url_with_offset_list: List of url_with_offset.
"""
_external_storage.restore_spilled_objects(object_refs,
url_with_offset_list)
return _external_storage.restore_spilled_objects(object_refs,
url_with_offset_list)
def delete_spilled_objects(urls: List[str]):
+9 -4
View File
@@ -12,6 +12,7 @@ import hashlib
import cython
import inspect
import uuid
import ray.ray_constants as ray_constants
ctypedef object (*FunctionDescriptor_from_cpp)(const CFunctionDescriptor &)
@@ -188,7 +189,8 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
function_name = function.__name__
class_name = ""
pickled_function_hash = hashlib.sha1(pickled_function).hexdigest()
pickled_function_hash = hashlib.shake_128(pickled_function).hexdigest(
ray_constants.ID_SIZE)
return cls(module_name, function_name, class_name,
pickled_function_hash)
@@ -208,7 +210,10 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
module_name = target_class.__module__
class_name = target_class.__name__
# Use a random uuid as function hash to solve actor name conflict.
return cls(module_name, "__init__", class_name, str(uuid.uuid4()))
return cls(
module_name, "__init__", class_name,
hashlib.shake_128(
uuid.uuid4().bytes).hexdigest(ray_constants.ID_SIZE))
@property
def module_name(self):
@@ -268,14 +273,14 @@ cdef class PythonFunctionDescriptor(FunctionDescriptor):
Returns:
ray.ObjectRef to represent the function descriptor.
"""
function_id_hash = hashlib.sha1()
function_id_hash = hashlib.shake_128()
# Include the function module and name in the hash.
function_id_hash.update(self.typed_descriptor.ModuleName())
function_id_hash.update(self.typed_descriptor.FunctionName())
function_id_hash.update(self.typed_descriptor.ClassName())
function_id_hash.update(self.typed_descriptor.FunctionHash())
# Compute the function ID.
function_id = function_id_hash.digest()
function_id = function_id_hash.digest(ray_constants.ID_SIZE)
return ray.FunctionID(function_id)
def is_actor_method(self):
+4 -3
View File
@@ -179,9 +179,10 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
c_bool plasma_objects_only)
CRayStatus Contains(const CObjectID &object_id, c_bool *has_object)
CRayStatus Wait(const c_vector[CObjectID] &object_ids, int num_objects,
int64_t timeout_ms, c_vector[c_bool] *results)
int64_t timeout_ms, c_vector[c_bool] *results,
c_bool fetch_local)
CRayStatus Delete(const c_vector[CObjectID] &object_ids,
c_bool local_only, c_bool delete_creating_tasks)
c_bool local_only)
CRayStatus TriggerGlobalGC()
c_string MemoryUsageString()
@@ -232,7 +233,7 @@ cdef extern from "ray/core_worker/core_worker.h" nogil:
(CRayStatus() nogil) check_signals
(void() nogil) gc_collect
(c_vector[c_string](const c_vector[CObjectID] &) nogil) spill_objects
(void(
(int64_t(
const c_vector[CObjectID] &,
const c_vector[c_string] &) nogil) restore_spilled_objects
(void(
+1 -1
View File
@@ -31,7 +31,7 @@ def check_id(b, size=kUniqueIDSize):
raise TypeError("Unsupported type: " + str(type(b)))
if len(b) != size:
raise ValueError("ID string needs to have length " +
str(size))
str(size) + ", got " + str(len(b)))
cdef extern from "ray/common/constants.h" nogil:
+2 -5
View File
@@ -37,7 +37,7 @@ def memory_summary():
return reply.memory_summary
def free(object_refs, local_only=False, delete_creating_tasks=False):
def free(object_refs, local_only=False):
"""Free a list of IDs from the in-process and plasma object stores.
This function is a low-level API which should be used in restricted
@@ -59,8 +59,6 @@ def free(object_refs, local_only=False, delete_creating_tasks=False):
object_refs (List[ObjectRef]): List of object refs to delete.
local_only (bool): Whether only deleting the list of objects in local
object store or all object stores.
delete_creating_tasks (bool): Whether also delete the object creating
tasks.
"""
worker = ray.worker.global_worker
@@ -83,5 +81,4 @@ def free(object_refs, local_only=False, delete_creating_tasks=False):
if len(object_refs) == 0:
return
worker.core_worker.free_objects(object_refs, local_only,
delete_creating_tasks)
worker.core_worker.free_objects(object_refs, local_only)
+1 -1
View File
@@ -22,7 +22,7 @@ from ray.ray_logging import setup_component_logger
logger = logging.getLogger(__name__)
# The groups are worker id, job id, and pid.
JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]{40})-(\d+)-(\d+)")
JOB_LOG_PATTERN = re.compile(".*worker-([0-9a-f]+)-(\d+)-(\d+)")
class LogFileInfo:
+19 -15
View File
@@ -15,11 +15,14 @@ from ray.autoscaler._private.constants import AUTOSCALER_UPDATE_INTERVAL_S
from ray.autoscaler._private.load_metrics import LoadMetrics
from ray.autoscaler._private.constants import \
AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE
from ray.autoscaler._private.util import DEBUG_AUTOSCALING_STATUS
import ray.gcs_utils
import ray.utils
import ray.ray_constants as ray_constants
from ray.ray_logging import setup_component_logger
from ray._raylet import GlobalStateAccessor
from ray.experimental.internal_kv import _internal_kv_put, \
_internal_kv_initialized
import redis
@@ -65,11 +68,7 @@ def parse_resource_demands(resource_load_by_shape):
except Exception:
logger.exception("Failed to parse resource demands.")
# Bound the total number of bundles to 2xMAX_RESOURCE_DEMAND_VECTOR_SIZE.
# This guarantees the resource demand scheduler bin packing algorithm takes
# a reasonable amount of time to run.
return waiting_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE], \
infeasible_bundles[:AUTOSCALER_MAX_RESOURCE_DEMAND_VECTOR_SIZE]
return waiting_bundles, infeasible_bundles
class Monitor:
@@ -184,14 +183,8 @@ class Monitor:
data: a resource request as JSON, e.g. {"CPU": 1}
"""
if not self.autoscaler:
return
try:
self.autoscaler.request_resources(json.loads(data))
except Exception:
# We don't want this to kill the monitor.
traceback.print_exc()
resource_request = json.loads(data)
self.load_metrics.set_resource_requests(resource_request)
def process_messages(self, max_messages=10000):
"""Process all messages ready in the subscription channels.
@@ -257,12 +250,23 @@ class Monitor:
# Handle messages from the subscription channels.
while True:
self.update_raylet_map()
self.update_load_metrics()
status = {
"load_metrics_report": self.load_metrics.summary()._asdict()
}
# Process autoscaling actions
if self.autoscaler:
# Only used to update the load metrics for the autoscaler.
self.update_raylet_map()
self.update_load_metrics()
self.autoscaler.update()
status[
"autoscaler_report"] = self.autoscaler.summary()._asdict()
as_json = json.dumps(status)
if _internal_kv_initialized():
_internal_kv_put(
DEBUG_AUTOSCALING_STATUS, as_json, overwrite=True)
# Process a round of messages.
self.process_messages()
+1 -1
View File
@@ -19,7 +19,7 @@ def env_bool(key, default):
return default
ID_SIZE = 20
ID_SIZE = 28
# The default maximum number of bytes to allocate to the object store unless
# overridden by the user.
+27
View File
@@ -1,8 +1,11 @@
import logging
import os
import sys
import threading
from logging.handlers import RotatingFileHandler
from typing import Callable
import ray
from ray.utils import binary_to_hex
@@ -258,3 +261,27 @@ def setup_and_get_worker_interceptor_logger(args,
# logger to add a newline at the end of string.
handler.terminator = ""
return logger
class WorkerStandardStreamDispatcher:
    """Registry that forwards emitted data to a set of named handlers.

    Handlers are kept as ``(name, handler)`` pairs in registration order.
    ``emit`` calls every registered handler with the emitted data.
    """

    def __init__(self):
        # Ordered list of (name, handler) pairs.
        self.handlers = []
        # Re-entrant on purpose: a handler running inside emit() may call
        # add_handler()/remove_handler() on the same thread. With a plain
        # Lock that call would block forever on the lock emit() holds.
        self._lock = threading.RLock()

    def add_handler(self, name: str, handler: Callable) -> None:
        """Register ``handler`` under ``name``.

        Args:
            name: Label later used by ``remove_handler``.
            handler: Callable invoked with the emitted data.
        """
        with self._lock:
            self.handlers.append((name, handler))

    def remove_handler(self, name: str) -> None:
        """Drop every handler that was registered under ``name``."""
        with self._lock:
            # Rebind to a fresh list instead of mutating in place, so an
            # emit() that is currently iterating is not disturbed.
            self.handlers = [pair for pair in self.handlers if pair[0] != name]

    def emit(self, data):
        """Call each registered handler with ``data`` in registration order."""
        with self._lock:
            # Iterate over a snapshot: handlers added or removed by a
            # callback during this emit take effect from the next emit.
            for _, handle in list(self.handlers):
                handle(data)
global_worker_stdstream_dispatcher = WorkerStandardStreamDispatcher()
+6 -26
View File
@@ -2,17 +2,15 @@
import asyncio
import logging
import os
import time
from ray._private.ray_microbenchmark_helpers import timeit
from ray._private.ray_client_microbenchmark import (main as
client_microbenchmark_main)
import numpy as np
import multiprocessing
import ray
logger = logging.getLogger(__name__)
# Only run tests matching this filter pattern.
filter_pattern = os.environ.get("TESTS_TO_RUN", "")
@ray.remote(num_cpus=0)
class Actor:
@@ -71,27 +69,6 @@ def small_value_batch(n):
return 0
def timeit(name, fn, multiplier=1):
    """Measure how many times per second ``fn`` can be called and print it.

    Nothing is measured when the module-level ``filter_pattern`` is not a
    substring of ``name``. Otherwise ``fn`` is called repeatedly for about
    one second as warmup, then for four rounds of about two seconds each.
    The printed figures are the mean and standard deviation of the
    per-round rates, each scaled by ``multiplier``.
    """
    if filter_pattern not in name:
        return

    # warmup
    warmup_began = time.time()
    while time.time() - warmup_began < 1:
        fn()

    # real run
    rates = []
    for _ in range(4):
        round_began = time.time()
        calls = 0
        while time.time() - round_began < 2:
            fn()
            calls += 1
        round_ended = time.time()
        rates.append(multiplier * calls / (round_ended - round_began))

    mean_rate = round(np.mean(rates), 2)
    spread = round(np.std(rates), 2)
    print(name, "per second", mean_rate, "+-", spread)
def check_optimized_build():
if not ray._raylet.OPTIMIZED:
msg = ("WARNING: Unoptimized build! "
@@ -277,6 +254,9 @@ def main():
ray.get([async_actor_work.remote(a) for _ in range(m)])
timeit("n:n async-actor calls async", async_actor_multi, m * n)
ray.shutdown()
client_microbenchmark_main()
if __name__ == "__main__":
+2 -5
View File
@@ -6,7 +6,6 @@ import logging
import os
import subprocess
import sys
from telnetlib import Telnet
import time
import urllib
import urllib.parse
@@ -172,8 +171,7 @@ def continue_debug_session():
ray.experimental.internal_kv._internal_kv_del(key)
return
host, port = session["pdb_address"].split(":")
with Telnet(host, int(port)) as tn:
tn.interact()
ray.util.rpdb.connect_pdb_client(host, int(port))
ray.experimental.internal_kv._internal_kv_del(key)
continue_debug_session()
return
@@ -215,8 +213,7 @@ def debug(address):
ray.experimental.internal_kv._internal_kv_get(
active_sessions[index]))
host, port = session["pdb_address"].split(":")
with Telnet(host, int(port)) as tn:
tn.interact()
ray.util.rpdb.connect_pdb_client(host, int(port))
@cli.command()
+3 -2
View File
@@ -74,7 +74,8 @@ def _try_to_compute_deterministic_class_id(cls, depth=5):
new_class_id = pickle.dumps(pickle.loads(class_id))
if new_class_id == class_id:
# We appear to have reached a fix point, so use this as the ID.
return hashlib.sha1(new_class_id).digest()
return hashlib.shake_128(new_class_id).digest(
ray_constants.ID_SIZE)
class_id = new_class_id
# We have not reached a fixed point, so we may end up with a different
@@ -82,7 +83,7 @@ def _try_to_compute_deterministic_class_id(cls, depth=5):
# same class definition being exported many many times.
logger.warning(
f"WARNING: Could not produce a deterministic class ID for class {cls}")
return hashlib.sha1(new_class_id).digest()
return hashlib.shake_128(new_class_id).digest(ray_constants.ID_SIZE)
def object_ref_deserializer(reduced_obj_ref, owner_address):
+8
View File
@@ -119,6 +119,14 @@ py_test(
deps = [":serve_lib"],
)
py_test(
name = "test_imported_backend",
size = "small",
srcs = serve_tests_srcs,
tags = ["exclusive"],
deps = [":serve_lib"],
)
# Runs test_api and test_failure with injected failures in the controller.
# TODO(simon): Tests are disabled until #11683 is fixed.
+93 -28
View File
@@ -4,22 +4,41 @@ import time
from functools import wraps
import os
from uuid import UUID
import threading
from typing import Any, Callable, Coroutine, Dict, List, Optional, Type, Union
from ray.serve.context import TaskContext
import ray
from ray.serve.constants import (DEFAULT_HTTP_HOST, DEFAULT_HTTP_PORT,
SERVE_CONTROLLER_NAME, HTTP_PROXY_TIMEOUT)
from ray.serve.controller import ServeController
from ray.serve.handle import RayServeHandle
from ray.serve.handle import RayServeHandle, RayServeSyncHandle
from ray.serve.utils import (block_until_http_ready, format_actor_name,
get_random_letters, logger, get_conda_env_dir)
from ray.serve.exceptions import RayServeException
from ray.serve.config import BackendConfig, ReplicaConfig, BackendMetadata
from ray.serve.config import (BackendConfig, ReplicaConfig, BackendMetadata,
HTTPConfig)
from ray.serve.env import CondaEnv
from ray.serve.router import RequestMetadata, Router
from ray.actor import ActorHandle
from typing import Any, Callable, Dict, List, Optional, Type, Union
_INTERNAL_CONTROLLER_NAME = None
global_async_loop = None


def create_or_get_async_loop_in_thread():
    """Return the module-wide asyncio loop, creating it on first use.

    The first call creates a new event loop, stores it in the module-level
    ``global_async_loop`` and starts ``run_forever`` on a daemon thread.
    Every later call returns that same stored loop.
    """
    global global_async_loop
    if global_async_loop is not None:
        return global_async_loop

    global_async_loop = asyncio.new_event_loop()
    runner = threading.Thread(
        target=global_async_loop.run_forever,
        daemon=True,
    )
    runner.start()
    return global_async_loop
def _set_internal_controller_name(name):
global _INTERNAL_CONTROLLER_NAME
@@ -36,6 +55,36 @@ def _ensure_connected(f: Callable) -> Callable:
return check
class ThreadProxiedRouter:
    """Owns a ``Router`` and the event loop its setup was scheduled on.

    With ``sync=True`` the router's ``setup_in_async_loop`` coroutine is
    submitted to the loop returned by
    ``create_or_get_async_loop_in_thread``. Otherwise it is scheduled as a
    task on the loop returned by ``asyncio.get_event_loop()``.
    """

    def __init__(self, controller_handle, sync: bool):
        self.router = Router(controller_handle)

        if not sync:
            # Async flavour: schedule setup on the caller's event loop.
            self.async_loop = asyncio.get_event_loop()
            self.async_loop.create_task(self.router.setup_in_async_loop())
            return

        # Sync flavour: hand setup over to the loop obtained from the helper.
        self.async_loop = create_or_get_async_loop_in_thread()
        asyncio.run_coroutine_threadsafe(
            self.router.setup_in_async_loop(),
            self.async_loop,
        )

    def _remote(self, endpoint_name, handle_options, request_data,
                kwargs) -> Coroutine:
        """Build request metadata and return the router's assign coroutine.

        The result of ``assign_request`` is returned as-is and is not
        awaited here.
        """
        request_id = get_random_letters(10)  # Used for debugging.
        metadata = RequestMetadata(
            request_id,
            endpoint_name,
            TaskContext.Python,
            call_method=handle_options.method_name,
            shard_key=handle_options.shard_key,
            http_method=handle_options.http_method,
            http_headers=handle_options.http_headers,
        )
        return self.router.assign_request(metadata, request_data, **kwargs)
class Client:
def __init__(self,
controller: ActorHandle,
@@ -45,15 +94,10 @@ class Client:
self._controller_name = controller_name
self._detached = detached
self._shutdown = False
self._http_host, self._http_port = ray.get(
controller.get_http_config.remote())
self._http_config = ray.get(controller.get_http_config.remote())
# NOTE(simon): Used to cache client.get_handle(endpoint) call. It will
# mostly grow in size, it will only shrink when user calls the
# .remove_endpoint method. This is fine because we expect the number of
# endpoints to be fairly small. However, in case this dictionary does
# grow very big, we can replace it with a LRU cache instead.
self._handle_cache: Dict[str, ActorHandle] = dict()
self._sync_proxied_router = None
self._async_proxied_router = None
# NOTE(edoakes): Need this because the shutdown order isn't guaranteed
# when the interpreter is exiting so we can't rely on __del__ (it
@@ -65,6 +109,18 @@ class Client:
atexit.register(shutdown_serve_client)
def _get_proxied_router(self, sync: bool):
    """Return the cached ``ThreadProxiedRouter`` for the requested flavour.

    One router is kept per flavour (sync and async). Each is created from
    ``self._controller`` the first time it is asked for and reused after.
    """
    if not sync:
        if self._async_proxied_router is None:
            self._async_proxied_router = ThreadProxiedRouter(
                self._controller, sync=False)
        return self._async_proxied_router

    if self._sync_proxied_router is None:
        self._sync_proxied_router = ThreadProxiedRouter(
            self._controller, sync=True)
    return self._sync_proxied_router
def __del__(self):
if not self._detached:
logger.debug("Shutting down Ray Serve because client went out of "
@@ -181,8 +237,8 @@ class Client:
num_cpus=0, resources={
node_id: 0.01
}).remote(
"http://{}:{}/-/routes".format(self._http_host,
self._http_port),
"http://{}:{}/-/routes".format(self._http_config.host,
self._http_config.port),
check_ready=check_ready,
timeout=HTTP_PROXY_TIMEOUT)
futures.append(future)
@@ -198,8 +254,6 @@ class Client:
Does not delete any associated backends.
"""
if endpoint in self._handle_cache:
del self._handle_cache[endpoint]
self._get_result(self._controller.delete_endpoint.remote(endpoint))
@_ensure_connected
@@ -410,10 +464,11 @@ class Client:
proportion))
@_ensure_connected
def get_handle(self,
endpoint_name: str,
missing_ok: Optional[bool] = False,
sync: bool = True) -> RayServeHandle:
def get_handle(
self,
endpoint_name: str,
missing_ok: Optional[bool] = False,
sync: bool = True) -> Union[RayServeHandle, RayServeSyncHandle]:
"""Retrieve RayServeHandle for service endpoint to invoke it from Python.
Args:
@@ -433,14 +488,26 @@ class Client:
if asyncio.get_event_loop().is_running() and sync:
logger.warning(
"You are retrieving a ServeHandle inside an asyncio loop. "
"You are retrieving a sync handle inside an asyncio loop. "
"Try getting client.get_handle(.., sync=False) to get better "
"performance.")
"performance. Learn more at https://docs.ray.io/en/master/"
"serve/advanced.html#sync-and-async-handles")
if endpoint_name not in self._handle_cache:
handle = RayServeHandle(self._controller, endpoint_name, sync=sync)
self._handle_cache[endpoint_name] = handle
return self._handle_cache[endpoint_name]
if not asyncio.get_event_loop().is_running() and not sync:
logger.warning(
"You are retrieving an async handle outside an asyncio loop. "
"You should make sure client.get_handle is called inside a "
"running event loop. Or call client.get_handle(.., sync=True) "
"to create sync handle. Learn more at https://docs.ray.io/en/"
"master/serve/advanced.html#sync-and-async-handles")
if sync:
handle = RayServeSyncHandle(
self._get_proxied_router(sync=sync), endpoint_name)
else:
handle = RayServeHandle(
self._get_proxied_router(sync=sync), endpoint_name)
return handle
def start(detached: bool = False,
@@ -492,9 +559,7 @@ def start(detached: bool = False,
max_task_retries=-1,
).remote(
controller_name,
http_host,
http_port,
http_middlewares,
HTTPConfig(http_host, http_port, http_middlewares),
detached=detached)
if http_host is not None:
+10 -10
View File
@@ -186,10 +186,10 @@ class RayServeReplica:
"backend_replica_starts",
description=("The number of time this replica "
"has been restarted due to failure."),
tag_keys=("backend", "replica_tag"))
tag_keys=("backend", "replica"))
self.restart_counter.set_default_tags({
"backend": self.backend_tag,
"replica_tag": self.replica_tag
"replica": self.replica_tag
})
self.queuing_latency_tracker = metrics.Histogram(
@@ -198,39 +198,39 @@ class RayServeReplica:
"The latency for queries waiting in the replica's queue "
"waiting to be processed or batched."),
boundaries=DEFAULT_LATENCY_BUCKET_MS,
tag_keys=("backend", "replica_tag"))
tag_keys=("backend", "replica"))
self.queuing_latency_tracker.set_default_tags({
"backend": self.backend_tag,
"replica_tag": self.replica_tag
"replica": self.replica_tag
})
self.processing_latency_tracker = metrics.Histogram(
"backend_processing_latency_ms",
description="The latency for queries to be processed",
boundaries=DEFAULT_LATENCY_BUCKET_MS,
tag_keys=("backend", "replica_tag", "batch_size"))
tag_keys=("backend", "replica", "batch_size"))
self.processing_latency_tracker.set_default_tags({
"backend": self.backend_tag,
"replica_tag": self.replica_tag
"replica": self.replica_tag
})
self.num_queued_items = metrics.Gauge(
"replica_queued_queries",
description=("Current number of queries queued in the "
"the backend replicas"),
tag_keys=("backend", "replica_tag"))
tag_keys=("backend", "replica"))
self.num_queued_items.set_default_tags({
"backend": self.backend_tag,
"replica_tag": self.replica_tag
"replica": self.replica_tag
})
self.num_processing_items = metrics.Gauge(
"replica_processing_queries",
description="Current number of queries being processed",
tag_keys=("backend", "replica_tag"))
tag_keys=("backend", "replica"))
self.num_processing_items.set_default_tags({
"backend": self.backend_tag,
"replica_tag": self.replica_tag
"replica": self.replica_tag
})
self.restart_counter.record(1)
+33
View File
@@ -0,0 +1,33 @@
from ray.serve.utils import import_class
class ImportedBackend:
    """Factory for a class that will dynamically import a backend class.

    This is intended to be used when the source code for a backend is
    installed in the worker environment but not the driver.

    Intended usage:
        >>> client = serve.connect()
        >>> client.create_backend("b", ImportedBackend("module.Class"), *args)

    This will import module.Class on the worker and proxy all relevant methods
    to it.
    """

    def __new__(cls, class_path):
        # Calling ImportedBackend(path) returns a new *class*, not an
        # instance. class_path is only resolved when that class is
        # instantiated.
        class ImportedBackend:
            def __init__(self, *args, **kwargs):
                self.wrapped = import_class(class_path)(*args, **kwargs)

            def reconfigure(self, *args, **kwargs):
                # NOTE(edoakes): we check that the reconfigure method is
                # present if the user specifies a user_config, so we need to
                # proxy it manually.
                return self.wrapped.reconfigure(*args, **kwargs)

            def __getattr__(self, attr):
                """Proxy all other methods to the wrapper class."""
                # __getattr__ only runs when normal lookup fails. If
                # `wrapped` itself is missing (instance created without
                # __init__, e.g. by copy/unpickle, or __init__ raised),
                # reading self.wrapped below would re-enter this method
                # until RecursionError. Fail with AttributeError instead.
                if attr == "wrapped":
                    raise AttributeError(attr)
                return getattr(self.wrapped, attr)

        return ImportedBackend

Some files were not shown because too many files have changed in this diff Show More