mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 23:08:32 +08:00
[Core] Type check ObjectRef (#9856)
* Type check ObjectRef * Bug fix * Port typing tests to bazel test
This commit is contained in:
@@ -1,9 +1,11 @@
|
||||
from typing import Any, Awaitable
|
||||
from typing import Any, Awaitable, TypeVar
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class ObjectRef(Awaitable[Any]):
|
||||
class ObjectRef(Awaitable[R]):
|
||||
pass
|
||||
|
||||
|
||||
class ObjectID(Awaitable[Any]):
|
||||
class ObjectID(Awaitable[R]):
|
||||
pass
|
||||
|
||||
+10
-1
@@ -99,7 +99,7 @@ py_test_module_list(
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
# TODO(barakmich): aws/ might want its own buildfile, or
|
||||
# TODO(barakmich): aws/ might want its own buildfile, or
|
||||
# py_test_module_list should support subdirectories.
|
||||
py_test(
|
||||
name = "test_autoscaler_aws",
|
||||
@@ -108,3 +108,12 @@ py_test(
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
# Note(simon): typing tests are not included in module list
|
||||
# because they requires globs and it might be refactored in the future.
|
||||
py_test(
|
||||
name = "test_typing",
|
||||
size = "small",
|
||||
srcs = SRCS + ["test_typing.py"] + glob(["typing_files"]),
|
||||
tags = ["exclusive"],
|
||||
deps = ["//:ray_lib"],
|
||||
)
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
import mypy.api as mypy_api
|
||||
import pytest
|
||||
|
||||
TYPING_TEST_DIRS = os.path.join(os.path.dirname(__file__), "typing_files")
|
||||
|
||||
|
||||
def test_typing_good():
|
||||
script = os.path.join(TYPING_TEST_DIRS, "check_typing_good.py")
|
||||
msg, _, status_code = mypy_api.run([script])
|
||||
assert status_code == 0, msg
|
||||
|
||||
|
||||
def test_typing_bad():
|
||||
script = os.path.join(TYPING_TEST_DIRS, "check_typing_bad.py")
|
||||
msg, _, status_code = mypy_api.run([script])
|
||||
assert status_code == 1, msg
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Make subprocess happy in bazel.
|
||||
os.environ["LC_ALL"] = "en_US.UTF-8"
|
||||
os.environ["LANG"] = "en_US.UTF-8"
|
||||
sys.exit(pytest.main(["-v", __file__]))
|
||||
@@ -0,0 +1,33 @@
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
|
||||
|
||||
@ray.remote
|
||||
def f(a: int) -> str:
|
||||
return "a = {}".format(a + 1)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def g(s: str) -> str:
|
||||
return s + " world"
|
||||
|
||||
|
||||
@ray.remote
|
||||
def h(a: str, b: int) -> str:
|
||||
return a
|
||||
|
||||
|
||||
# Does not typecheck due to incorrect input type:
|
||||
a = h.remote(1, 1)
|
||||
b = f.remote("hello")
|
||||
c = f.remote(1, 1)
|
||||
d = f.remote(1) + 1
|
||||
|
||||
# Check return type
|
||||
ref_to_str = f.remote(1)
|
||||
unwrapped_str = ray.get(ref_to_str)
|
||||
unwrapped_str + 100 # Fail
|
||||
|
||||
# Check ObjectRef[T] as args
|
||||
f.remote(ref_to_str) # Fail
|
||||
@@ -0,0 +1,32 @@
|
||||
import ray
|
||||
|
||||
ray.init()
|
||||
|
||||
|
||||
@ray.remote
|
||||
def f(a: int) -> str:
|
||||
return "a = {}".format(a + 1)
|
||||
|
||||
|
||||
@ray.remote
|
||||
def g(s: str) -> str:
|
||||
return s + " world"
|
||||
|
||||
|
||||
@ray.remote
|
||||
def h(a: str, b: int) -> str:
|
||||
return a
|
||||
|
||||
|
||||
# Make sure the function arg is check
|
||||
print(f.remote(1))
|
||||
object_ref_str = f.remote(1)
|
||||
|
||||
# Make sure the ObjectRef[T] variant of function arg is checked
|
||||
print(g.remote(object_ref_str))
|
||||
|
||||
# Make sure there can be mixed T0 and ObjectRef[T1] for args
|
||||
print(h.remote(object_ref_str, 100))
|
||||
|
||||
# Make sure the return type is checked.
|
||||
xy = ray.get(object_ref_str) + "y"
|
||||
+37
-28
@@ -1,4 +1,5 @@
|
||||
from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload
|
||||
# yapf: disable
|
||||
from typing import Any, Callable, Generic, Optional, TypeVar, Union, overload, Sequence, List
|
||||
|
||||
from ray._raylet import ObjectRef
|
||||
|
||||
@@ -16,58 +17,66 @@ T9 = TypeVar("T9")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class RemoteFunction(Generic[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]):
|
||||
def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], Any]) -> None: pass
|
||||
class RemoteFunction(Generic[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]):
|
||||
def __init__(self, function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]) -> None: pass
|
||||
|
||||
@overload
|
||||
def remote(self) -> ObjectRef: ...
|
||||
def remote(self) -> ObjectRef[R]: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef]) -> ObjectRef: ...
|
||||
def remote(self, arg0: Union[T0, ObjectRef[T0]]) -> ObjectRef[R]: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef]) -> ObjectRef: ...
|
||||
def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]]) -> ObjectRef[R]: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef]) -> ObjectRef: ...
|
||||
def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]]) -> ObjectRef[R]: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef]) -> ObjectRef: ...
|
||||
def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]]) -> ObjectRef[R]: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef]) -> ObjectRef: ...
|
||||
def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]], arg4: Union[T4, ObjectRef[T4]]) -> ObjectRef[R]: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef]) -> ObjectRef: ...
|
||||
def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]], arg4: Union[T4, ObjectRef[T4]], arg5: Union[T5, ObjectRef[T5]]) -> ObjectRef[R]: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef]) -> ObjectRef: ...
|
||||
def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]], arg4: Union[T4, ObjectRef[T4]], arg5: Union[T5, ObjectRef[T5]], arg6: Union[T6, ObjectRef[T6]]) -> ObjectRef[R]: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef]) -> ObjectRef: ...
|
||||
def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]], arg4: Union[T4, ObjectRef[T4]], arg5: Union[T5, ObjectRef[T5]], arg6: Union[T6, ObjectRef[T6]], arg7: Union[T7, ObjectRef[T7]]) -> ObjectRef[R]: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef], arg8: Union[T8, ObjectRef]) -> ObjectRef: ...
|
||||
def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]], arg4: Union[T4, ObjectRef[T4]], arg5: Union[T5, ObjectRef[T5]], arg6: Union[T6, ObjectRef[T6]], arg7: Union[T7, ObjectRef[T7]], arg8: Union[T8, ObjectRef[T8]]) -> ObjectRef[R]: ...
|
||||
@overload
|
||||
def remote(self, arg0: Union[T0, ObjectRef], arg1: Union[T1, ObjectRef], arg2: Union[T2, ObjectRef], arg3: Union[T3, ObjectRef], arg4: Union[T4, ObjectRef], arg5: Union[T5, ObjectRef], arg6: Union[T6, ObjectRef], arg7: Union[T7, ObjectRef], arg8: Union[T8, ObjectRef], arg9: Union[T9, ObjectRef]) -> ObjectRef: ...
|
||||
def remote(self, *args, **kwargs) -> ObjectRef:
|
||||
pass
|
||||
def remote(self, arg0: Union[T0, ObjectRef[T0]], arg1: Union[T1, ObjectRef[T1]], arg2: Union[T2, ObjectRef[T2]], arg3: Union[T3, ObjectRef[T3]], arg4: Union[T4, ObjectRef[T4]], arg5: Union[T5, ObjectRef[T5]], arg6: Union[T6, ObjectRef[T6]], arg7: Union[T7, ObjectRef[T7]], arg8: Union[T8, ObjectRef[T8]], arg9: Union[T9, ObjectRef[T9]]) -> ObjectRef[R]: ...
|
||||
def remote(self, *args, **kwargs) -> ObjectRef[R]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def remote(function: Callable[[], R]) -> RemoteFunction[None, None, None, None, None, None, None, None, None, None]: ...
|
||||
def remote(function: Callable[[], R]) -> RemoteFunction[R, None, None, None, None, None, None, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0], R]) -> RemoteFunction[T0, None, None, None, None, None, None, None, None, None]: ...
|
||||
def remote(function: Callable[[T0], R]) -> RemoteFunction[R, T0, None, None, None, None, None, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1], R]) -> RemoteFunction[T0, T1, None, None, None, None, None, None, None, None]: ...
|
||||
def remote(function: Callable[[T0, T1], R]) -> RemoteFunction[R, T0, T1, None, None, None, None, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2], R]) -> RemoteFunction[T0, T1, T2, None, None, None, None, None, None, None]: ...
|
||||
def remote(function: Callable[[T0, T1, T2], R]) -> RemoteFunction[R, T0, T1, T2, None, None, None, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3], R]) -> RemoteFunction[T0, T1, T2, T3, None, None, None, None, None, None]: ...
|
||||
def remote(function: Callable[[T0, T1, T2, T3], R]) -> RemoteFunction[R, T0, T1, T2, T3, None, None, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4], R]) -> RemoteFunction[T0, T1, T2, T3, T4, None, None, None, None, None]: ...
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4], R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, None, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, None, None, None, None]: ...
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5], R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, None, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, None, None, None]: ...
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6], R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, None, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, None, None]: ...
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7], R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, T7, None, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, None]: ...
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8], R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, None]: ...
|
||||
@overload
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: ...
|
||||
def remote(function: Callable[[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9], R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: ...
|
||||
# Pass on typing actors for now. The following makes it so no type errors are generated for actors.
|
||||
@overload
|
||||
def remote(t: type) -> Any: ...
|
||||
def remote(function: Callable[..., R]) -> RemoteFunction[T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: pass
|
||||
def remote(function: Callable[..., R]) -> RemoteFunction[R, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9]: ...
|
||||
|
||||
|
||||
|
||||
@overload
|
||||
def get(object_refs: Sequence[ObjectRef[Any]], timeout: Optional[float] = None) -> List[Any]: ...
|
||||
@overload
|
||||
def get(object_refs: Sequence[ObjectRef[R]], timeout: Optional[float] = None) -> List[R]: ...
|
||||
@overload
|
||||
def get(object_refs: ObjectRef[R], timeout: Optional[float] = None) -> R: ...
|
||||
|
||||
Reference in New Issue
Block a user