diff --git a/.travis.yml b/.travis.yml index e3b39c41f..be0b7664f 100644 --- a/.travis.yml +++ b/.travis.yml @@ -356,10 +356,6 @@ script: # ray operator tests - (cd deploy/ray-operator && export CC=gcc && suppress_output go build && suppress_output go test ./...) - # test ray typing - - mypy --strict ./ci/travis/check_typing_good.py - - mypy --strict ./ci/travis/check_typing_bad.py && return 1 || return 0 - # bazel python tests. This should be run last to keep its logs at the end of travis logs. - if [ $RAY_CI_PYTHON_AFFECTED == "1" ]; then ./ci/keep_alive bazel test --config=ci $(./scripts/bazel_export_options) --test_tag_filters=-jenkins_only python/ray/tests/...; fi # NO MORE TESTS BELOW, keep them above. diff --git a/ci/travis/check_typing_good.py b/ci/travis/check_typing_good.py deleted file mode 100644 index 81128ff05..000000000 --- a/ci/travis/check_typing_good.py +++ /dev/null @@ -1,26 +0,0 @@ -import ray - -ray.init() - - -@ray.remote -def f(a: int) -> str: - return "a = {}".format(a + 1) - - -@ray.remote -def g(s: str) -> str: - return s + " world" - - -@ray.remote -def h(a: str, b: int) -> str: - return a - - -print(f.remote(1)) -x = f.remote(1) -print(g.remote(x)) - -# typechecks but doesn't run -print(ray.get(f.remote(x))) diff --git a/python/ray/_raylet.pyi b/python/ray/_raylet.pyi index b5b5a403e..691620b27 100644 --- a/python/ray/_raylet.pyi +++ b/python/ray/_raylet.pyi @@ -1,9 +1,11 @@ -from typing import Any, Awaitable +from typing import Any, Awaitable, TypeVar + +R = TypeVar("R") -class ObjectRef(Awaitable[Any]): +class ObjectRef(Awaitable[R]): pass -class ObjectID(Awaitable[Any]): +class ObjectID(Awaitable[R]): pass diff --git a/python/ray/tests/BUILD b/python/ray/tests/BUILD index c5c82c3d2..50f7be593 100644 --- a/python/ray/tests/BUILD +++ b/python/ray/tests/BUILD @@ -99,7 +99,7 @@ py_test_module_list( deps = ["//:ray_lib"], ) -# TODO(barakmich): aws/ might want its own buildfile, or +# TODO(barakmich): aws/ might want its own buildfile, or # py_test_module_list should support subdirectories. py_test( name = "test_autoscaler_aws", @@ -108,3 +108,12 @@ py_test( deps = ["//:ray_lib"], ) +# Note(simon): typing tests are not included in module list +# because they requires globs and it might be refactored in the future. +py_test( + name = "test_typing", + size = "small", + srcs = SRCS + ["test_typing.py"] + glob(["typing_files"]), + tags = ["exclusive"], + deps = ["//:ray_lib"], +) diff --git a/python/ray/tests/test_typing.py b/python/ray/tests/test_typing.py new file mode 100644 index 000000000..512a53cea --- /dev/null +++ b/python/ray/tests/test_typing.py @@ -0,0 +1,26 @@ +import sys +import os + +import mypy.api as mypy_api +import pytest + +TYPING_TEST_DIRS = os.path.join(os.path.dirname(__file__), "typing_files") + + +def test_typing_good(): + script = os.path.join(TYPING_TEST_DIRS, "check_typing_good.py") + msg, _, status_code = mypy_api.run([script]) + assert status_code == 0, msg + + +def test_typing_bad(): + script = os.path.join(TYPING_TEST_DIRS, "check_typing_bad.py") + msg, _, status_code = mypy_api.run([script]) + assert status_code == 1, msg + + +if __name__ == "__main__": + # Make subprocess happy in bazel. + os.environ["LC_ALL"] = "en_US.UTF-8" + os.environ["LANG"] = "en_US.UTF-8" + sys.exit(pytest.main(["-v", __file__])) diff --git a/ci/travis/check_typing_bad.py b/python/ray/tests/typing_files/check_typing_bad.py similarity index 57% rename from ci/travis/check_typing_bad.py rename to python/ray/tests/typing_files/check_typing_bad.py index 2ed95d033..0433e4800 100644 --- a/ci/travis/check_typing_bad.py +++ b/python/ray/tests/typing_files/check_typing_bad.py @@ -18,8 +18,16 @@ def h(a: str, b: int) -> str: return a -# Does not typecheck: +# Does not typecheck due to incorrect input type: a = h.remote(1, 1) b = f.remote("hello") c = f.remote(1, 1) d = f.remote(1) + 1 + +# Check return type +ref_to_str = f.remote(1) +unwrapped_str = ray.get(ref_to_str) +unwrapped_str + 100 # Fail + +# Check ObjectRef[T] as args +f.remote(ref_to_str) # Fail diff --git a/python/ray/tests/typing_files/check_typing_good.py b/python/ray/tests/typing_files/check_typing_good.py new file mode 100644 index 000000000..b81f8527d --- /dev/null +++ b/python/ray/tests/typing_files/check_typing_good.py @@ -0,0 +1,32 @@ +import ray + +ray.init() + + +@ray.remote +def f(a: int) -> str: + return "a = {}".format(a + 1) + + +@ray.remote +def g(s: str) -> str: + return s + " world" + + +@ray.remote +def h(a: str, b: int) -> str: + return a + + +# Make sure the function arg is check +print(f.remote(1)) +object_ref_str = f.remote(1) + +# Make sure the ObjectRef[T] variant of function arg is checked +print(g.remote(object_ref_str)) + +# Make sure there can be mixed T0 and ObjectRef[T1] for args +print(h.remote(object_ref_str, 100)) + +# Make sure the return type is checked. +xy = ray.get(object_ref_str) + "y" diff --git a/python/ray/worker.pyi b/python/ray/worker.pyi index 1f5bbcbf9..8954f1936 100644 --- a/python/ray/worker.pyi +++ b/python/ray/worker.pyi @@ -1,4 +1,5 @@ -from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload +# yapf: disable +from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload, Sequence, List from ray._raylet import ObjectRef @@ -16,58 +17,66 @@ T9 = TypeVar("T9") R = TypeVar("R") -class RemoteFunction(Generic[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]): - def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], Any]) -> None: pass +class RemoteFunction(Generic[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]): + def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]) -> None: pass @overload - def remote(self) -> ObjectRef: ... + def remote(self) -> ObjectRef[R]: ... @overload - def remote(self, arg0: Union[T0, ObjectRef]) -> ObjectRef: ... + def remote(self, arg0: Union[T0, ObjectRef[T0]]) -> ObjectRef[R]: ... @overload - def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef]) -> ObjectRef: ... + def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]]) -> ObjectRef[R]: ... @overload - def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef]) -> ObjectRef: ... + def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]]) -> ObjectRef[R]: ... @overload - def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef]) -> ObjectRef: ... + def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]]) -> ObjectRef[R]: ... @overload - def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef]) -> ObjectRef: ... + def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]], arg4: Union[T4, ObjectRef[T4]]) -> ObjectRef[R]: ... @overload - def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef]) -> ObjectRef: ... + def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]], arg4: Union[T4, ObjectRef[T4]], arg5: Union[T5, ObjectRef[T5]]) -> ObjectRef[R]: ... @overload - def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef]) -> ObjectRef: ... + def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]], arg4: Union[T4, ObjectRef[T4]], arg5: Union[T5, ObjectRef[T5]], arg6: Union[T6, ObjectRef[T6]]) -> ObjectRef[R]: ... @overload - def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef]) -> ObjectRef: ... + def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]], arg4: Union[T4, ObjectRef[T4]], arg5: Union[T5, ObjectRef[T5]], arg6: Union[T6, ObjectRef[T6]], arg7: Union[T7, ObjectRef[T7]]) -> ObjectRef[R]: ... @overload - def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef], arg8: Union[T8, ObjectRef]) -> ObjectRef: ... + def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]], arg4: Union[T4, ObjectRef[T4]], arg5: Union[T5, ObjectRef[T5]], arg6: Union[T6, ObjectRef[T6]], arg7: Union[T7, ObjectRef[T7]], arg8: Union[T8, ObjectRef[T8]]) -> ObjectRef[R]: ... @overload - def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef], arg8: Union[T8, ObjectRef], arg9: Union[T9, ObjectRef]) -> ObjectRef: ... - def remote(self, *args, **kwargs) -> ObjectRef: - pass + def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]], arg4: Union[T4, ObjectRef[T4]], arg5: Union[T5, ObjectRef[T5]], arg6: Union[T6, ObjectRef[T6]], arg7: Union[T7, ObjectRef[T7]], arg8: Union[T8, ObjectRef[T8]], arg9: Union[T9, ObjectRef[T9]]) -> ObjectRef[R]: ... + def remote(self, *args, **kwargs) -> ObjectRef[R]: ... @overload -def remote(function: Callable[[], R]) -> RemoteFunction[None, None, None, None, None, None, None, None, None, None]: ... +def remote(function: Callable[[], R]) -> RemoteFunction[R, None, None, None, None, None, None, None, None, None, None]: ... @overload -def remote(function: Callable[[T0], R]) -> RemoteFunction[T0, None, None, None, None, None, None, None, None, None]: ... +def remote(function: Callable[[T0], R]) -> RemoteFunction[R, T0, None, None, None, None, None, None, None, None, None]: ... @overload -def remote(function: Callable[[T0, T1], R]) -> RemoteFunction[T0, T1, None, None, None, None, None, None, None, None]: ... +def remote(function: Callable[[T0, T1], R]) -> RemoteFunction[R, T0, T1, None, None, None, None, None, None, None, None]: ... @overload -def remote(function: Callable[[T0, T1, T2], R]) -> RemoteFunction[T0, T1, T2, None, None, None, None, None, None, None]: ... +def remote(function: Callable[[T0, T1, T2], R]) -> RemoteFunction[R, T0, T1, T2, None, None, None, None, None, None, None]: ... @overload -def remote(function: Callable[[T0, T1, T2, T3], R]) -> RemoteFunction[T0, T1, T2, T3, None, None, None, None, None, None]: ... +def remote(function: Callable[[T0, T1, T2, T3], R]) -> RemoteFunction[R, T0, T1, T2, T3, None, None, None, None, None, None]: ... @overload -def remote(function: Callable[[T0, T1, T2, T3, T4], R]) -> RemoteFunction[T0, T1, T2, T3, T4, None, None, None, None, None]: ... +def remote(function: Callable[[T0, T1, T2, T3, T4], R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, None, None, None, None, None]: ... @overload -def remote(function: Callable[[T0, T1, T2, T3, T4, T5], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, None, None, None, None]: ... +def remote(function: Callable[[T0, T1, T2, T3, T4, T5], R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, None, None, None, None]: ... @overload -def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, None, None, None]: ... +def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, None, None, None]: ... @overload -def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, None, None]: ... +def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, T7, None, None]: ... @overload -def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, None]: ... +def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, None]: ... @overload -def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: ... +def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: ... # Pass on typing actors for now. The following makes it so no type errors are generated for actors. @overload def remote(t: type) -> Any: ... -def remote(function: Callable[..., R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: pass +def remote(function: Callable[..., R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: ... + + + +@overload +def get(object_refs: Sequence[ObjectRef[Any]], timeout: Optional[float] = None) -> List[Any]: ... +@overload +def get(object_refs: Sequence[ObjectRef[R]], timeout: Optional[float] = None) -> List[R]: ... +@overload +def get(object_refs: ObjectRef[R], timeout: Optional[float] = None) -> R: ...