mirror of
https://github.com/wassname/ray.git
synced 2026-07-05 10:34:05 +08:00
Support ray task type checking (#9574)
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
from typing import Any, Awaitable
|
||||
|
||||
|
||||
class ObjectRef(Awaitable[Any]):
|
||||
pass
|
||||
|
||||
|
||||
class ObjectID(Awaitable[Any]):
|
||||
pass
|
||||
@@ -1,14 +1,15 @@
|
||||
import copy
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import logging
|
||||
import sys
|
||||
import click
|
||||
import random
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
try: # py3
|
||||
from shlex import quote
|
||||
@@ -84,9 +85,11 @@ def request_resources(num_cpus=None, bundles=None):
|
||||
r.publish(AUTOSCALER_RESOURCE_REQUEST_CHANNEL, json.dumps(bundles))
|
||||
|
||||
|
||||
def create_or_update_cluster(config_file, override_min_workers,
|
||||
override_max_workers, no_restart, restart_only,
|
||||
yes, override_cluster_name, no_config_cache):
|
||||
def create_or_update_cluster(
|
||||
config_file: str, override_min_workers: Optional[int],
|
||||
override_max_workers: Optional[int], no_restart: bool,
|
||||
restart_only: bool, yes: bool, override_cluster_name: Optional[str],
|
||||
no_config_cache: bool) -> None:
|
||||
"""Create or updates an autoscaling Ray cluster from a config json."""
|
||||
config = yaml.safe_load(open(config_file).read())
|
||||
if override_min_workers is not None:
|
||||
@@ -100,7 +103,8 @@ def create_or_update_cluster(config_file, override_min_workers,
|
||||
override_cluster_name)
|
||||
|
||||
|
||||
def _bootstrap_config(config, no_config_cache=False):
|
||||
def _bootstrap_config(config: Dict[str, Any],
|
||||
no_config_cache: bool = False) -> Dict[str, Any]:
|
||||
config = prepare_config(config)
|
||||
|
||||
hasher = hashlib.sha1()
|
||||
@@ -125,8 +129,9 @@ def _bootstrap_config(config, no_config_cache=False):
|
||||
return resolved_config
|
||||
|
||||
|
||||
def teardown_cluster(config_file, yes, workers_only, override_cluster_name,
|
||||
keep_min_workers):
|
||||
def teardown_cluster(config_file: str, yes: bool, workers_only: bool,
|
||||
override_cluster_name: Optional[str],
|
||||
keep_min_workers: bool):
|
||||
"""Destroys all nodes of a Ray cluster described by a config json."""
|
||||
|
||||
config = yaml.safe_load(open(config_file).read())
|
||||
@@ -410,8 +415,9 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes,
|
||||
provider.cleanup()
|
||||
|
||||
|
||||
def attach_cluster(config_file, start, use_screen, use_tmux,
|
||||
override_cluster_name, new, port_forward):
|
||||
def attach_cluster(config_file: str, start: bool, use_screen: bool,
|
||||
use_tmux: bool, override_cluster_name: Optional[str],
|
||||
new: bool, port_forward: Any):
|
||||
"""Attaches to a screen for the specified cluster.
|
||||
|
||||
Arguments:
|
||||
@@ -452,17 +458,17 @@ def attach_cluster(config_file, start, use_screen, use_tmux,
|
||||
port_forward=port_forward)
|
||||
|
||||
|
||||
def exec_cluster(config_file,
|
||||
def exec_cluster(config_file: str,
|
||||
*,
|
||||
cmd=None,
|
||||
run_env="auto",
|
||||
screen=False,
|
||||
tmux=False,
|
||||
stop=False,
|
||||
start=False,
|
||||
override_cluster_name=None,
|
||||
port_forward=None,
|
||||
with_output=False):
|
||||
cmd: Any = None,
|
||||
run_env: str = "auto",
|
||||
screen: bool = False,
|
||||
tmux: bool = False,
|
||||
stop: bool = False,
|
||||
start: bool = False,
|
||||
override_cluster_name: Optional[str] = None,
|
||||
port_forward: Any = None,
|
||||
with_output: bool = False):
|
||||
"""Runs a command on the specified cluster.
|
||||
|
||||
Arguments:
|
||||
@@ -571,12 +577,12 @@ def _exec(updater,
|
||||
run_env=run_env)
|
||||
|
||||
|
||||
def rsync(config_file,
|
||||
source,
|
||||
target,
|
||||
override_cluster_name,
|
||||
down,
|
||||
all_nodes=False):
|
||||
def rsync(config_file: str,
|
||||
source: Optional[str],
|
||||
target: Optional[str],
|
||||
override_cluster_name: Optional[str],
|
||||
down: bool,
|
||||
all_nodes: bool = False):
|
||||
"""Rsyncs files.
|
||||
|
||||
Arguments:
|
||||
@@ -639,7 +645,8 @@ def rsync(config_file,
|
||||
provider.cleanup()
|
||||
|
||||
|
||||
def get_head_node_ip(config_file, override_cluster_name):
|
||||
def get_head_node_ip(config_file: str,
|
||||
override_cluster_name: Optional[str]) -> str:
|
||||
"""Returns head node IP for given configuration file if exists."""
|
||||
|
||||
config = yaml.safe_load(open(config_file).read())
|
||||
@@ -659,7 +666,8 @@ def get_head_node_ip(config_file, override_cluster_name):
|
||||
return head_node_ip
|
||||
|
||||
|
||||
def get_worker_node_ips(config_file, override_cluster_name):
|
||||
def get_worker_node_ips(config_file: str,
|
||||
override_cluster_name: Optional[str]) -> str:
|
||||
"""Returns worker node IPs for given configuration file."""
|
||||
|
||||
config = yaml.safe_load(open(config_file).read())
|
||||
@@ -695,10 +703,10 @@ def _get_worker_nodes(config, override_cluster_name):
|
||||
provider.cleanup()
|
||||
|
||||
|
||||
def _get_head_node(config,
|
||||
config_file,
|
||||
override_cluster_name,
|
||||
create_if_needed=False):
|
||||
def _get_head_node(config: Dict[str, Any],
|
||||
config_file: str,
|
||||
override_cluster_name: Optional[str],
|
||||
create_if_needed: bool = False) -> str:
|
||||
provider = get_node_provider(config["provider"], config["cluster_name"])
|
||||
try:
|
||||
head_node_tags = {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Dict
|
||||
|
||||
import yaml
|
||||
|
||||
from ray.autoscaler.command_runner import SSHCommandRunner, DockerCommandRunner
|
||||
@@ -102,7 +104,8 @@ def load_class(path):
|
||||
return getattr(module, class_str)
|
||||
|
||||
|
||||
def get_node_provider(provider_config, cluster_name):
|
||||
def get_node_provider(provider_config: Dict[str, Any],
|
||||
cluster_name: str) -> Any:
|
||||
importer = NODE_PROVIDERS.get(provider_config["type"])
|
||||
if importer is None:
|
||||
raise NotImplementedError("Unsupported node provider: {}".format(
|
||||
@@ -225,7 +228,7 @@ class NodeProvider:
|
||||
cluster_name,
|
||||
process_runner,
|
||||
use_internal_ip,
|
||||
docker_config=None):
|
||||
docker_config=None) -> Any:
|
||||
""" Returns the CommandRunner class used to perform SSH commands.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import collections
|
||||
import os
|
||||
import json
|
||||
import threading
|
||||
import hashlib
|
||||
import json
|
||||
import jsonschema
|
||||
import os
|
||||
import threading
|
||||
from typing import Any, Dict
|
||||
|
||||
import ray
|
||||
import ray.services as services
|
||||
@@ -45,7 +46,7 @@ class ConcurrentCounter:
|
||||
return sum(self._counter.values())
|
||||
|
||||
|
||||
def validate_config(config):
|
||||
def validate_config(config: Dict[str, Any]) -> None:
|
||||
"""Required Dicts indicate that no extra fields can be introduced."""
|
||||
if not isinstance(config, dict):
|
||||
raise ValueError("Config {} is not a dictionary".format(config))
|
||||
@@ -65,7 +66,7 @@ def prepare_config(config):
|
||||
return with_defaults
|
||||
|
||||
|
||||
def fillout_defaults(config):
|
||||
def fillout_defaults(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
defaults = get_default_config(config["provider"])
|
||||
defaults.update(config)
|
||||
defaults["auth"] = defaults.get("auth", {})
|
||||
|
||||
@@ -42,7 +42,7 @@ def _get_actor(name):
|
||||
return handle
|
||||
|
||||
|
||||
def get_actor(name):
|
||||
def get_actor(name: str) -> ray.actor.ActorHandle:
|
||||
"""Get a named actor which was previously created.
|
||||
|
||||
If the actor doesn't exist, an exception will be raised.
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload
|
||||
|
||||
from ray._raylet import ObjectRef
|
||||
|
||||
|
||||
T0 = TypeVar("T0")
|
||||
T1 = TypeVar("T1")
|
||||
T2 = TypeVar("T2")
|
||||
T3 = TypeVar("T3")
|
||||
T4 = TypeVar("T4")
|
||||
T5 = TypeVar("T5")
|
||||
T6 = TypeVar("T6")
|
||||
T7 = TypeVar("T7")
|
||||
T8 = TypeVar("T8")
|
||||
T9 = TypeVar("T9")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class RemoteFunction(Generic[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]):
|
||||
def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], Any]) -> None: pass
|
||||
|
||||
@overload
|
||||
def remote(self) -> ObjectRef: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef]) -> ObjectRef: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef]) -> ObjectRef: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef]) -> ObjectRef: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef]) -> ObjectRef: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef]) -> ObjectRef: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef]) -> ObjectRef: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef]) -> ObjectRef: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef]) -> ObjectRef: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef], arg8: Union[T8, ObjectRef]) -> ObjectRef: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef], arg8: Union[T8, ObjectRef], arg9: Union[T9, ObjectRef]) -> ObjectRef: ...
|
||||
def remote(self, *args, **kwargs) -> ObjectRef:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
def remote(function: Callable[[], R]) -> RemoteFunction[None, None, None, None, None, None, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0], R]) -> RemoteFunction[T0, None, None, None, None, None, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1], R]) -> RemoteFunction[T0, T1, None, None, None, None, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2], R]) -> RemoteFunction[T0, T1, T2, None, None, None, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3], R]) -> RemoteFunction[T0, T1, T2, T3, None, None, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4], R]) -> RemoteFunction[T0, T1, T2, T3, T4, None, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: ...
|
||||
# Pass on typing actors for now. The following makes it so no type errors are generated for actors.
|
||||
@overload
|
||||
def remote(t: type) -> Any: ...
|
||||
def remote(function: Callable[..., R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: pass
|
||||
Reference in New Issue
Block a user