diff --git a/python/ray/autoscaler/command_runner.py b/python/ray/autoscaler/command_runner.py index c025b0979..0249c856b 100644 --- a/python/ray/autoscaler/command_runner.py +++ b/python/ray/autoscaler/command_runner.py @@ -381,7 +381,7 @@ class SSHCommandRunner(CommandRunnerInterface): else: return self.provider.external_ip(self.node_id) - def wait_for_ip(self, deadline): + def _wait_for_ip(self, deadline): # if we have IP do not print waiting info ip = self._get_node_ip() if ip is not None: @@ -413,7 +413,7 @@ class SSHCommandRunner(CommandRunnerInterface): # I think that's reasonable. deadline = time.time() + NODE_START_WAIT_S with LogTimer(self.log_prefix + "Got IP"): - ip = self.wait_for_ip(deadline) + ip = self._wait_for_ip(deadline) cli_logger.doassert(ip is not None, "Could not get node IP.") # todo: msg diff --git a/python/ray/tests/test_command_runner.py b/python/ray/tests/test_command_runner.py index 2e747d882..977b2e299 100644 --- a/python/ray/tests/test_command_runner.py +++ b/python/ray/tests/test_command_runner.py @@ -1,7 +1,8 @@ import pytest from ray.tests.test_autoscaler import MockProvider, MockProcessRunner -from ray.autoscaler.command_runner import SSHCommandRunner, \ - _with_environment_variables, DockerCommandRunner, KubernetesCommandRunner +from ray.autoscaler.command_runner import CommandRunnerInterface, \ + SSHCommandRunner, _with_environment_variables, DockerCommandRunner, \ + KubernetesCommandRunner from ray.autoscaler.docker import DOCKER_MOUNT_PREFIX from getpass import getuser import hashlib @@ -27,6 +28,33 @@ def test_environment_variable_encoder_dict(): assert res == expected +def test_command_runner_interface_abstraction_violation(): + """Enforces the CommandRunnerInterface functions on the subclasses. + + This is important to make sure the subclasses do not violate the + function abstractions. If you need to add a new function to one of + the CommandRunnerInterface subclasses, you have to add it to + CommandRunnerInterface and all of its subclasses. + """ + + cmd_runner_interface_public_functions = dir(CommandRunnerInterface) + allowed_public_interface_functions = { + func + for func in cmd_runner_interface_public_functions + if not func.startswith("_") + } + for subcls in [ + SSHCommandRunner, DockerCommandRunner, KubernetesCommandRunner + ]: + subclass_available_functions = dir(subcls) + subclass_public_functions = { + func + for func in subclass_available_functions + if not func.startswith("_") + } + assert allowed_public_interface_functions == subclass_public_functions + + def test_ssh_command_runner(): process_runner = MockProcessRunner() provider = MockProvider()