diff --git a/python/ray/experimental/serve/api.py b/python/ray/experimental/serve/api.py index 14ed208cf..01bff7e31 100644 --- a/python/ray/experimental/serve/api.py +++ b/python/ray/experimental/serve/api.py @@ -106,9 +106,9 @@ def _start_replica(backend_tag): runner = creator() setup_done = runner._ray_serve_setup.remote( - backend_tag, global_state.router_actor_handle) + backend_tag, global_state.router_actor_handle, runner) ray.get(setup_done) - runner._ray_serve_main_loop.remote(runner) + runner._ray_serve_main_loop.remote() global_state.backend_replicas[backend_tag].append(runner) global_state.metric_monitor_handle.add_target.remote(runner) diff --git a/python/ray/experimental/serve/task_runner.py b/python/ray/experimental/serve/task_runner.py index d0fa4c7de..af072e131 100644 --- a/python/ray/experimental/serve/task_runner.py +++ b/python/ray/experimental/serve/task_runner.py @@ -79,14 +79,14 @@ class RayServeMixin: }, } - def _ray_serve_setup(self, my_name, _ray_serve_router_handle): + def _ray_serve_setup(self, my_name, router_handle, my_handle): self._ray_serve_dequeue_requestr_name = my_name - self._ray_serve_router_handle = _ray_serve_router_handle + self._ray_serve_router_handle = router_handle + self._ray_serve_self_handle = my_handle self._ray_serve_setup_completed = True - def _ray_serve_main_loop(self, my_handle): + def _ray_serve_main_loop(self): assert self._ray_serve_setup_completed - self._ray_serve_self_handle = my_handle # Only retrieve the next task if we have completed previous task. if self._ray_serve_cached_work_token is None: @@ -104,7 +104,7 @@ class RayServeMixin: self._ray_serve_cached_work_token = None else: self._ray_serve_cached_work_token = work_token - self._ray_serve_self_handle._ray_serve_main_loop.remote(my_handle) + self._ray_serve_self_handle._ray_serve_main_loop.remote() return if work_item.request_context == TaskContext.Web: @@ -136,7 +136,7 @@ class RayServeMixin: # It will now tail recursively schedule the main_loop again. # TODO(simon): remove tail recursion, ask router to callback instead - self._ray_serve_self_handle._ray_serve_main_loop.remote(my_handle) + self._ray_serve_self_handle._ray_serve_main_loop.remote() class TaskRunnerBackend(TaskRunner, RayServeMixin): diff --git a/python/ray/experimental/serve/tests/test_task_runner.py b/python/ray/experimental/serve/tests/test_task_runner.py index e26fa4434..7746a93df 100644 --- a/python/ray/experimental/serve/tests/test_task_runner.py +++ b/python/ray/experimental/serve/tests/test_task_runner.py @@ -34,8 +34,8 @@ def test_runner_actor(serve_instance): runner = TaskRunnerActor.remote(echo) - runner._ray_serve_setup.remote(CONSUMER_NAME, q) - runner._ray_serve_main_loop.remote(runner) + runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner) + runner._ray_serve_main_loop.remote() q.link.remote(PRODUCER_NAME, CONSUMER_NAME) @@ -69,8 +69,8 @@ def test_ray_serve_mixin(serve_instance): runner = CustomActor.remote(3) - runner._ray_serve_setup.remote(CONSUMER_NAME, q) - runner._ray_serve_main_loop.remote(runner) + runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner) + runner._ray_serve_main_loop.remote() q.link.remote(PRODUCER_NAME, CONSUMER_NAME) @@ -97,8 +97,8 @@ def test_task_runner_check_context(serve_instance): runner = TaskRunnerActor.remote(echo) - runner._ray_serve_setup.remote(CONSUMER_NAME, q) - runner._ray_serve_main_loop.remote(runner) + runner._ray_serve_setup.remote(CONSUMER_NAME, q, runner) + runner._ray_serve_main_loop.remote() q.link.remote(PRODUCER_NAME, CONSUMER_NAME) result_token = ray.ObjectID(