diff --git a/BUILD.bazel b/BUILD.bazel
index a91afeccf..7fd852b6c 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -98,6 +98,11 @@ proto_library(
     deps = [":common_proto"],
 )
 
+python_grpc_compile(
+    name = "core_worker_py_proto",
+    deps = [":core_worker_proto"],
+)
+
 cc_proto_library(
     name = "worker_cc_proto",
     deps = ["core_worker_proto"],
@@ -1163,6 +1168,7 @@ filegroup(
     name = "all_py_proto",
     srcs = [
         "common_py_proto",
+        "core_worker_py_proto",
         "gcs_py_proto",
         "node_manager_py_proto",
         "reporter_py_proto",
@@ -1235,6 +1241,8 @@ genrule(
         sed -i -E 's/from src.ray.protobuf/from ./' "$$WORK_DIR/python/ray/core/generated/node_manager_pb2_grpc.py" &&
         sed -i -E 's/from src.ray.protobuf/from ./' "$$WORK_DIR/python/ray/core/generated/reporter_pb2.py" &&
         sed -i -E 's/from src.ray.protobuf/from ./' "$$WORK_DIR/python/ray/core/generated/reporter_pb2_grpc.py" &&
+        sed -i -E 's/from src.ray.protobuf/from ./' "$$WORK_DIR/python/ray/core/generated/core_worker_pb2.py" &&
+        sed -i -E 's/from src.ray.protobuf/from ./' "$$WORK_DIR/python/ray/core/generated/core_worker_pb2_grpc.py" &&
         echo "$$WORK_DIR" > $@
     """,
     local = 1,
diff --git a/python/ray/dashboard/dashboard.py b/python/ray/dashboard/dashboard.py
index 2cb86cef5..53bd772d5 100644
--- a/python/ray/dashboard/dashboard.py
+++ b/python/ray/dashboard/dashboard.py
@@ -30,6 +30,8 @@ from ray.core.generated import node_manager_pb2
 from ray.core.generated import node_manager_pb2_grpc
 from ray.core.generated import reporter_pb2
 from ray.core.generated import reporter_pb2_grpc
+from ray.core.generated import core_worker_pb2
+from ray.core.generated import core_worker_pb2_grpc
 import ray.ray_constants as ray_constants
 
 # Logger for this module. It should be configured at the entry point
@@ -269,6 +271,23 @@ class Dashboard(object):
             return aiohttp.web.json_response(
                 self.raylet_stats.get_profiling_info(profiling_id))
 
+        async def kill_actor(req) -> aiohttp.web.Response:
+            """Ask the core worker at ip_address:port to kill an actor.
+
+            Reads "actor_id", "ip_address" and "port" from the query string
+            and answers 400 if any of them is missing.
+            """
+            actor_id = req.query.get("actor_id")
+            ip_address = req.query.get("ip_address")
+            port = req.query.get("port")
+            if None in (actor_id, ip_address, port):
+                # A missing parameter would otherwise fail later as a
+                # TypeError (e.g. int(None)) instead of a client error.
+                raise aiohttp.web.HTTPBadRequest(
+                    text="actor_id, ip_address and port are required")
+            return await json_response(
+                self.raylet_stats.kill_actor(actor_id, ip_address, port))
+
         async def logs(req) -> aiohttp.web.Response:
             hostname = req.query.get("hostname")
             pid = req.query.get("pid")
@@ -307,6 +326,7 @@ class Dashboard(object):
         self.app.router.add_get("/api/check_profiling_status",
                                 check_profiling_status)
         self.app.router.add_get("/api/get_profiling_info", get_profiling_info)
+        self.app.router.add_get("/api/kill_actor", kill_actor)
         self.app.router.add_get("/api/logs", logs)
         self.app.router.add_get("/api/errors", errors)
 
@@ -656,6 +676,37 @@ class RayletStats(threading.Thread):
         assert profiling_stats, "profiling not finished"
         return json.loads(profiling_stats.profiling_stats)
 
+    def kill_actor(self, actor_id, ip_address, port):
+        """Send an asynchronous KillActor RPC to a core worker.
+
+        Args:
+            actor_id: Hex-encoded ID sent as the request's intended_actor_id.
+            ip_address: Host of the core worker's gRPC server.
+            port: Port of that server; anything int() accepts.
+
+        Returns:
+            An empty dict. The RPC outcome is not reported to the caller.
+        """
+        # Build the request first so a malformed actor_id fails before a
+        # channel has been opened.
+        request = core_worker_pb2.KillActorRequest(
+            intended_actor_id=ray.utils.hex_to_binary(actor_id))
+        channel = grpc.insecure_channel("{}:{}".format(ip_address, int(port)))
+        stub = core_worker_pb2_grpc.CoreWorkerServiceStub(channel)
+
+        def _callback(reply_future):
+            try:
+                # result() re-raises the RPC error, if there was one.
+                reply_future.result()
+            finally:
+                # A channel is opened per request, so release it as soon as
+                # the call has finished.
+                channel.close()
+
+        reply_future = stub.KillActor.future(request)
+        reply_future.add_done_callback(_callback)
+        return {}
+
     def run(self):
         counter = 0
         while True:
diff --git a/python/ray/tests/test_metrics.py b/python/ray/tests/test_metrics.py
index b61c5adc9..d9ef78857 100644
--- a/python/ray/tests/test_metrics.py
+++ b/python/ray/tests/test_metrics.py
@@ -14,7 +14,7 @@ from ray.test_utils import RayTestTimeoutException
 
 
 def test_worker_stats(shutdown_only):
-    ray.init(num_cpus=1, include_webui=False)
+    addresses = ray.init(num_cpus=1, include_webui=True)
     raylet = ray.nodes()[0]
     num_cpus = raylet["Resources"]["CPU"]
     raylet_address = "{}:{}".format(raylet["NodeManagerAddress"],
@@ -112,6 +112,42 @@ def test_worker_stats(shutdown_only):
                 or "travis" in process)
         break
 
+    # Test kill_actor.
+    def actor_killed(pid):
+        """Return True if signalling the given pid fails (process gone)."""
+        try:
+            # Signal 0 only performs error checking; nothing is delivered.
+            os.kill(pid, 0)
+        except OSError:
+            return True
+        else:
+            return False
+
+    webui_url = addresses["webui_url"]
+    webui_url = webui_url.replace("localhost", "http://127.0.0.1")
+    for worker in reply.workers_stats:
+        if worker.is_driver:
+            continue
+        requests.get(
+            webui_url + "/api/kill_actor",
+            params={
+                "actor_id": ray.utils.binary_to_hex(
+                    worker.core_worker_stats.actor_id),
+                "ip_address": worker.core_worker_stats.ip_address,
+                "port": worker.core_worker_stats.port
+            })
+    timeout_seconds = 20
+    start_time = time.time()
+    while True:
+        if time.time() - start_time > timeout_seconds:
+            raise RayTestTimeoutException("Timed out while killing actors")
+        if all(
+                actor_killed(worker.pid) for worker in reply.workers_stats
+                if not worker.is_driver):
+            break
+        # Poll at a modest rate rather than spinning a CPU core.
+        time.sleep(0.1)
+
 
 def test_raylet_info_endpoint(shutdown_only):
     addresses = ray.init(include_webui=True, num_cpus=6)