diff --git a/.travis.yml b/.travis.yml index ff8ad8be9..1b60e06c3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -291,6 +291,10 @@ script: # ray operator tests - (cd deploy/ray-operator && export CC=gcc && suppress_output go build && suppress_output go test ./...) + # test ray typing + - mypy --strict ./ci/travis/check_typing_good.py + - mypy --strict ./ci/travis/check_typing_bad.py && return 1 || return 0 + # bazel python tests. This should be run last to keep its logs at the end of travis logs. - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci --test_tag_filters=-jenkins_only python/ray/tests/...; fi - if [ $RAY_CI_TUNE_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci --test_tag_filters=-jenkins_only python/ray/tune/...; fi diff --git a/ci/travis/check_typing_bad.py b/ci/travis/check_typing_bad.py new file mode 100644 index 000000000..3df119cfe --- /dev/null +++ b/ci/travis/check_typing_bad.py @@ -0,0 +1,26 @@ +import ray + + +ray.init() + + +@ray.remote +def f(a: int) -> str: + return "a = {}".format(a + 1) + + +@ray.remote +def g(s: str) -> str: + return s + " world" + + +@ray.remote +def h(a: str, b: int) -> str: + return a + + +# Does not typecheck: +a = h.remote(1, 1) +b = f.remote("hello") +c = f.remote(1, 1) +d = f.remote(1) + 1 diff --git a/ci/travis/check_typing_good.py b/ci/travis/check_typing_good.py new file mode 100644 index 000000000..d83bf1ce3 --- /dev/null +++ b/ci/travis/check_typing_good.py @@ -0,0 +1,27 @@ +import ray + + +ray.init() + + +@ray.remote +def f(a: int) -> str: + return "a = {}".format(a + 1) + + +@ray.remote +def g(s: str) -> str: + return s + " world" + + +@ray.remote +def h(a: str, b: int) -> str: + return a + + +print(f.remote(1)) +x = f.remote(1) +print(g.remote(x)) + +# typechecks but doesn't run +print(ray.get(f.remote(x))) diff --git a/ci/travis/install-dependencies.sh b/ci/travis/install-dependencies.sh index eed92b77b..4c613b1ee 100755 --- a/ci/travis/install-dependencies.sh +++ b/ci/travis/install-dependencies.sh @@ -239,7 +239,7 @@ install_dependencies() { opencv-python-headless pyyaml pandas==1.0.5 requests feather-format lxml openpyxl xlrd \ py-spy pytest pytest-timeout networkx tabulate aiohttp uvicorn dataclasses pygments werkzeug \ kubernetes flask grpcio pytest-sugar pytest-rerunfailures pytest-asyncio scikit-learn==0.22.2 numba \ - Pillow prometheus_client boto3 pettingzoo) + Pillow prometheus_client boto3 pettingzoo mypy) if [ "${OSTYPE}" != msys ]; then # These packages aren't Windows-compatible pip_packages+=(blist) # https://github.com/DanielStutzbach/blist/issues/81#issue-391460716 diff --git a/python/ray/_raylet.pyi b/python/ray/_raylet.pyi new file mode 100644 index 000000000..b5b5a403e --- /dev/null +++ b/python/ray/_raylet.pyi @@ -0,0 +1,9 @@ +from typing import Any, Awaitable + + +class ObjectRef(Awaitable[Any]): + pass + + +class ObjectID(Awaitable[Any]): + pass diff --git a/python/ray/autoscaler/commands.py b/python/ray/autoscaler/commands.py index 157fc82c1..522b23935 100644 --- a/python/ray/autoscaler/commands.py +++ b/python/ray/autoscaler/commands.py @@ -1,14 +1,15 @@ import copy import hashlib import json +import logging import os +import random +import sys import tempfile import time -import logging -import sys -import click -import random +from typing import Any, Dict, Optional +import click import yaml try: # py3 from shlex import quote @@ -84,9 +85,11 @@ def request_resources(num_cpus=None, bundles=None): r.publish(AUTOSCALER_RESOURCE_REQUEST_CHANNEL, json.dumps(bundles)) -def create_or_update_cluster(config_file, override_min_workers, - override_max_workers, no_restart, restart_only, - yes, override_cluster_name, no_config_cache): +def create_or_update_cluster( + config_file: str, override_min_workers: Optional[int], + override_max_workers: Optional[int], no_restart: bool, + restart_only: bool, yes: bool, override_cluster_name: Optional[str], + no_config_cache: bool) -> None: """Create or updates an autoscaling Ray cluster from a config json.""" config = yaml.safe_load(open(config_file).read()) if override_min_workers is not None: @@ -100,7 +103,8 @@ def create_or_update_cluster(config_file, override_min_workers, override_cluster_name) -def _bootstrap_config(config, no_config_cache=False): +def _bootstrap_config(config: Dict[str, Any], + no_config_cache: bool = False) -> Dict[str, Any]: config = prepare_config(config) hasher = hashlib.sha1() @@ -125,8 +129,9 @@ def _bootstrap_config(config, no_config_cache=False): return resolved_config -def teardown_cluster(config_file, yes, workers_only, override_cluster_name, - keep_min_workers): +def teardown_cluster(config_file: str, yes: bool, workers_only: bool, + override_cluster_name: Optional[str], + keep_min_workers: bool): """Destroys all nodes of a Ray cluster described by a config json.""" config = yaml.safe_load(open(config_file).read()) @@ -410,8 +415,9 @@ def get_or_create_head_node(config, config_file, no_restart, restart_only, yes, provider.cleanup() -def attach_cluster(config_file, start, use_screen, use_tmux, - override_cluster_name, new, port_forward): +def attach_cluster(config_file: str, start: bool, use_screen: bool, + use_tmux: bool, override_cluster_name: Optional[str], + new: bool, port_forward: Any): """Attaches to a screen for the specified cluster. Arguments: @@ -452,17 +458,17 @@ def attach_cluster(config_file, start, use_screen, use_tmux, port_forward=port_forward) -def exec_cluster(config_file, +def exec_cluster(config_file: str, *, - cmd=None, - run_env="auto", - screen=False, - tmux=False, - stop=False, - start=False, - override_cluster_name=None, - port_forward=None, - with_output=False): + cmd: Any = None, + run_env: str = "auto", + screen: bool = False, + tmux: bool = False, + stop: bool = False, + start: bool = False, + override_cluster_name: Optional[str] = None, + port_forward: Any = None, + with_output: bool = False): """Runs a command on the specified cluster. Arguments: @@ -571,12 +577,12 @@ def _exec(updater, run_env=run_env) -def rsync(config_file, - source, - target, - override_cluster_name, - down, - all_nodes=False): +def rsync(config_file: str, + source: Optional[str], + target: Optional[str], + override_cluster_name: Optional[str], + down: bool, + all_nodes: bool = False): """Rsyncs files. Arguments: @@ -639,7 +645,8 @@ def rsync(config_file, provider.cleanup() -def get_head_node_ip(config_file, override_cluster_name): +def get_head_node_ip(config_file: str, + override_cluster_name: Optional[str]) -> str: """Returns head node IP for given configuration file if exists.""" config = yaml.safe_load(open(config_file).read()) @@ -659,7 +666,8 @@ def get_head_node_ip(config_file, override_cluster_name): return head_node_ip -def get_worker_node_ips(config_file, override_cluster_name): +def get_worker_node_ips(config_file: str, + override_cluster_name: Optional[str]) -> str: """Returns worker node IPs for given configuration file.""" config = yaml.safe_load(open(config_file).read()) @@ -695,10 +703,10 @@ def _get_worker_nodes(config, override_cluster_name): provider.cleanup() -def _get_head_node(config, - config_file, - override_cluster_name, - create_if_needed=False): +def _get_head_node(config: Dict[str, Any], + config_file: str, + override_cluster_name: Optional[str], + create_if_needed: bool = False) -> str: provider = get_node_provider(config["provider"], config["cluster_name"]) try: head_node_tags = { diff --git a/python/ray/autoscaler/node_provider.py b/python/ray/autoscaler/node_provider.py index 7698a0bd4..a621b2e96 100644 --- a/python/ray/autoscaler/node_provider.py +++ b/python/ray/autoscaler/node_provider.py @@ -1,6 +1,8 @@ import importlib import logging import os +from typing import Any, Dict + import yaml from ray.autoscaler.command_runner import SSHCommandRunner, DockerCommandRunner @@ -102,7 +104,8 @@ def load_class(path): return getattr(module, class_str) -def get_node_provider(provider_config, cluster_name): +def get_node_provider(provider_config: Dict[str, Any], + cluster_name: str) -> Any: importer = NODE_PROVIDERS.get(provider_config["type"]) if importer is None: raise NotImplementedError("Unsupported node provider: {}".format( @@ -225,7 +228,7 @@ class NodeProvider: cluster_name, process_runner, use_internal_ip, - docker_config=None): + docker_config=None) -> Any: """ Returns the CommandRunner class used to perform SSH commands. Args: diff --git a/python/ray/autoscaler/util.py b/python/ray/autoscaler/util.py index 13ae5262f..47fedd959 100644 --- a/python/ray/autoscaler/util.py +++ b/python/ray/autoscaler/util.py @@ -1,9 +1,10 @@ import collections -import os -import json -import threading import hashlib +import json import jsonschema +import os +import threading +from typing import Any, Dict import ray import ray.services as services @@ -45,7 +46,7 @@ class ConcurrentCounter: return sum(self._counter.values()) -def validate_config(config): +def validate_config(config: Dict[str, Any]) -> None: """Required Dicts indicate that no extra fields can be introduced.""" if not isinstance(config, dict): raise ValueError("Config {} is not a dictionary".format(config)) @@ -65,7 +66,7 @@ def prepare_config(config): return with_defaults -def fillout_defaults(config): +def fillout_defaults(config: Dict[str, Any]) -> Dict[str, Any]: defaults = get_default_config(config["provider"]) defaults.update(config) defaults["auth"] = defaults.get("auth", {}) diff --git a/python/ray/py.typed b/python/ray/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/python/ray/util/named_actors.py b/python/ray/util/named_actors.py index ff84fd9f5..64f377d09 100644 --- a/python/ray/util/named_actors.py +++ b/python/ray/util/named_actors.py @@ -42,7 +42,7 @@ def _get_actor(name): return handle -def get_actor(name): +def get_actor(name: str) -> ray.actor.ActorHandle: """Get a named actor which was previously created. If the actor doesn't exist, an exception will be raised. diff --git a/python/ray/worker.pyi b/python/ray/worker.pyi new file mode 100644 index 000000000..1f5bbcbf9 --- /dev/null +++ b/python/ray/worker.pyi @@ -0,0 +1,73 @@ +from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload + +from ray._raylet import ObjectRef + + +T0 = TypeVar("T0") +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") +T4 = TypeVar("T4") +T5 = TypeVar("T5") +T6 = TypeVar("T6") +T7 = TypeVar("T7") +T8 = TypeVar("T8") +T9 = TypeVar("T9") +R = TypeVar("R") + + +class RemoteFunction(Generic[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]): + def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], Any]) -> None: pass + + @overload + def remote(self) -> ObjectRef: ... + @overload + def remote(self, arg0: Union[T0, ObjectRef]) -> ObjectRef: ... + @overload + def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef]) -> ObjectRef: ... + @overload + def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef]) -> ObjectRef: ... + @overload + def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef]) -> ObjectRef: ... + @overload + def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef]) -> ObjectRef: ... + @overload + def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef]) -> ObjectRef: ... + @overload + def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef]) -> ObjectRef: ... + @overload + def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef]) -> ObjectRef: ... + @overload + def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef], arg8: Union[T8, ObjectRef]) -> ObjectRef: ... + @overload + def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef], arg8: Union[T8, ObjectRef], arg9: Union[T9, ObjectRef]) -> ObjectRef: ... + def remote(self, *args, **kwargs) -> ObjectRef: + pass + + +@overload +def remote(function: Callable[[], R]) -> RemoteFunction[None, None, None, None, None, None, None, None, None, None]: ... +@overload +def remote(function: Callable[[T0], R]) -> RemoteFunction[T0, None, None, None, None, None, None, None, None, None]: ... +@overload +def remote(function: Callable[[T0, T1], R]) -> RemoteFunction[T0, T1, None, None, None, None, None, None, None, None]: ... +@overload +def remote(function: Callable[[T0, T1, T2], R]) -> RemoteFunction[T0, T1, T2, None, None, None, None, None, None, None]: ... +@overload +def remote(function: Callable[[T0, T1, T2, T3], R]) -> RemoteFunction[T0, T1, T2, T3, None, None, None, None, None, None]: ... +@overload +def remote(function: Callable[[T0, T1, T2, T3, T4], R]) -> RemoteFunction[T0, T1, T2, T3, T4, None, None, None, None, None]: ... +@overload +def remote(function: Callable[[T0, T1, T2, T3, T4, T5], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, None, None, None, None]: ... +@overload +def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, None, None, None]: ... +@overload +def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, None, None]: ... +@overload +def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, None]: ... +@overload +def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: ... +# Pass on typing actors for now. The following makes it so no type errors are generated for actors. +@overload +def remote(t: type) -> Any: ... +def remote(function: Callable[..., R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: pass