Support ray task type checking (#9574)

This commit is contained in:
Philipp Moritz
2020-07-21 19:05:42 -07:00
committed by GitHub
parent 997d1162e3
commit a5f4659d9f
11 changed files with 194 additions and 43 deletions
+9
View File
@@ -0,0 +1,9 @@
from typing import Any, Awaitable
class ObjectRef(Awaitable[Any]):
pass
class ObjectID(Awaitable[Any]):
pass
+42 -34
View File
@@ -1,14 +1,15 @@
import copy
import hashlib
import json
import logging
import os
import random
import sys
import tempfile
import time
import logging
import sys
import click
import random
from typing import Any, Dict, Optional
import click
import yaml
try: # py3
from shlex import quote
@@ -84,9 +85,11 @@ def request_resources(num_cpus=None, bundles=None):
r.publish(AUTOSCALER_RESOURCE_REQUEST_CHANNEL, json.dumps(bundles))
def create_or_update_cluster(config_file, override_min_workers,
override_max_workers, no_restart, restart_only,
yes, override_cluster_name, no_config_cache):
def create_or_update_cluster(
config_file: str, override_min_workers: Optional[int],
override_max_workers: Optional[int], no_restart: bool,
restart_only: bool, yes: bool, override_cluster_name: Optional[str],
no_config_cache: bool) -> None:
"""Create or updates an autoscaling Ray cluster from a config json."""
config = yaml.safe_load(open(config_file).read())
if override_min_workers is not None:
@@ -100,7 +103,8 @@ def create_or_update_cluster(config_file, override_min_workers,
override_cluster_name)
def _bootstrap_config(config, no_config_cache=False):
def _bootstrap_config(config: Dict[str, Any],
no_config_cache: bool = False) -> Dict[str, Any]:
config = prepare_config(config)
hasher = hashlib.sha1()
@@ -125,8 +129,9 @@ def _bootstrap_config(config, no_config_cache=False):
return resolved_config
def teardown_cluster(config_file, yes, workers_only, override_cluster_name,
keep_min_workers):
def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
override_cluster_name: Optional[str],
keep_min_workers: bool):
"""Destroys all nodes of a Ray cluster described by a config json."""
config = yaml.safe_load(open(config_file).read())
@@ -410,8 +415,9 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
provider.cleanup()
def attach_cluster(config_file, start, use_screen, use_tmux,
override_cluster_name, new, port_forward):
def attach_cluster(config_file: str, start: bool, use_screen: bool,
use_tmux: bool, override_cluster_name: Optional[str],
new: bool, port_forward: Any):
"""Attaches to a screen for the specified cluster.
Arguments:
@@ -452,17 +458,17 @@ def attach_cluster(config_file, start, use_screen, use_tmux,
port_forward=port_forward)
def exec_cluster(config_file,
def exec_cluster(config_file: str,
*,
cmd=None,
run_env="auto",
screen=False,
tmux=False,
stop=False,
start=False,
override_cluster_name=None,
port_forward=None,
with_output=False):
cmd: Any = None,
run_env: str = "auto",
screen: bool = False,
tmux: bool = False,
stop: bool = False,
start: bool = False,
override_cluster_name: Optional[str] = None,
port_forward: Any = None,
with_output: bool = False):
"""Runs a command on the specified cluster.
Arguments:
@@ -571,12 +577,12 @@ def _exec(updater,
run_env=run_env)
def rsync(config_file,
source,
target,
override_cluster_name,
down,
all_nodes=False):
def rsync(config_file: str,
source: Optional[str],
target: Optional[str],
override_cluster_name: Optional[str],
down: bool,
all_nodes: bool = False):
"""Rsyncs files.
Arguments:
@@ -639,7 +645,8 @@ def rsync(config_file,
provider.cleanup()
def get_head_node_ip(config_file, override_cluster_name):
def get_head_node_ip(config_file: str,
override_cluster_name: Optional[str]) -> str:
"""Returns head node IP for given configuration file if exists."""
config = yaml.safe_load(open(config_file).read())
@@ -659,7 +666,8 @@ def get_head_node_ip(config_file, override_cluster_name):
return head_node_ip
def get_worker_node_ips(config_file, override_cluster_name):
def get_worker_node_ips(config_file: str,
override_cluster_name: Optional[str]) -> str:
"""Returns worker node IPs for given configuration file."""
config = yaml.safe_load(open(config_file).read())
@@ -695,10 +703,10 @@ def _get_worker_nodes(config, override_cluster_name):
provider.cleanup()
def _get_head_node(config,
config_file,
override_cluster_name,
create_if_needed=False):
def _get_head_node(config: Dict[str, Any],
config_file: str,
override_cluster_name: Optional[str],
create_if_needed: bool = False) -> str:
provider = get_node_provider(config["provider"], config["cluster_name"])
try:
head_node_tags = {
+5 -2
View File
@@ -1,6 +1,8 @@
import importlib
import logging
import os
from typing import Any, Dict
import yaml
from ray.autoscaler.command_runner import SSHCommandRunner, DockerCommandRunner
@@ -102,7 +104,8 @@ def load_class(path):
return getattr(module, class_str)
def get_node_provider(provider_config, cluster_name):
def get_node_provider(provider_config: Dict[str, Any],
cluster_name: str) -> Any:
importer = NODE_PROVIDERS.get(provider_config["type"])
if importer is None:
raise NotImplementedError("Unsupported node provider: {}".format(
@@ -225,7 +228,7 @@ class NodeProvider:
cluster_name,
process_runner,
use_internal_ip,
docker_config=None):
docker_config=None) -> Any:
""" Returns the CommandRunner class used to perform SSH commands.
Args:
+6 -5
View File
@@ -1,9 +1,10 @@
import collections
import os
import json
import threading
import hashlib
import json
import jsonschema
import os
import threading
from typing import Any, Dict
import ray
import ray.services as services
@@ -45,7 +46,7 @@ class ConcurrentCounter:
return sum(self._counter.values())
def validate_config(config):
def validate_config(config: Dict[str, Any]) -> None:
"""Required Dicts indicate that no extra fields can be introduced."""
if not isinstance(config, dict):
raise ValueError("Config {} is not a dictionary".format(config))
@@ -65,7 +66,7 @@ def prepare_config(config):
return with_defaults
def fillout_defaults(config):
def fillout_defaults(config: Dict[str, Any]) -> Dict[str, Any]:
defaults = get_default_config(config["provider"])
defaults.update(config)
defaults["auth"] = defaults.get("auth", {})
View File
+1 -1
View File
@@ -42,7 +42,7 @@ def _get_actor(name):
return handle
def get_actor(name):
def get_actor(name: str) -> ray.actor.ActorHandle:
"""Get a named actor which was previously created.
If the actor doesn't exist, an exception will be raised.
+73
View File
@@ -0,0 +1,73 @@
from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload
from ray._raylet import ObjectRef
T0 = TypeVar("T0")
T1 = TypeVar("T1")
T2 = TypeVar("T2")
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
T7 = TypeVar("T7")
T8 = TypeVar("T8")
T9 = TypeVar("T9")
R = TypeVar("R")
class RemoteFunction(Generic[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]):
def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], Any]) -> None: pass
@overload
def remote(self) -> ObjectRef: ...
@overload
def remote(self, arg0: Union[T0, ObjectRef]) -> ObjectRef: ...
@overload
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef]) -> ObjectRef: ...
@overload
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef]) -> ObjectRef: ...
@overload
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef]) -> ObjectRef: ...
@overload
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef]) -> ObjectRef: ...
@overload
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef]) -> ObjectRef: ...
@overload
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef]) -> ObjectRef: ...
@overload
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef]) -> ObjectRef: ...
@overload
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef], arg8: Union[T8, ObjectRef]) -> ObjectRef: ...
@overload
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef], arg8: Union[T8, ObjectRef], arg9: Union[T9, ObjectRef]) -> ObjectRef: ...
def remote(self, *args, **kwargs) -> ObjectRef:
pass
@overload
def remote(function: Callable[[], R]) -> RemoteFunction[None, None, None, None, None, None, None, None, None, None]: ...
@overload
def remote(function: Callable[[T0], R]) -> RemoteFunction[T0, None, None, None, None, None, None, None, None, None]: ...
@overload
def remote(function: Callable[[T0, T1], R]) -> RemoteFunction[T0, T1, None, None, None, None, None, None, None, None]: ...
@overload
def remote(function: Callable[[T0, T1, T2], R]) -> RemoteFunction[T0, T1, T2, None, None, None, None, None, None, None]: ...
@overload
def remote(function: Callable[[T0, T1, T2, T3], R]) -> RemoteFunction[T0, T1, T2, T3, None, None, None, None, None, None]: ...
@overload
def remote(function: Callable[[T0, T1, T2, T3, T4], R]) -> RemoteFunction[T0, T1, T2, T3, T4, None, None, None, None, None]: ...
@overload
def remote(function: Callable[[T0, T1, T2, T3, T4, T5], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, None, None, None, None]: ...
@overload
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, None, None, None]: ...
@overload
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, None, None]: ...
@overload
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, None]: ...
@overload
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: ...
# Pass on typing actors for now. The following makes it so no type errors are generated for actors.
@overload
def remote(t: type) -> Any: ...
def remote(function: Callable[..., R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: pass