Add numba test (#7298) (#7487)

This commit is contained in:
Landcold7
2020-03-08 03:12:25 +08:00
committed by GitHub
parent 115468de2c
commit beb9b02dbd
3 changed files with 47 additions and 2 deletions
+8
View File
@@ -383,3 +383,11 @@ py_test(
tags = ["exclusive"],
deps = ["//:ray_lib"],
)
py_test(
name = "test_numba",
size = "small",
srcs = ["test_numba.py"],
tags = ["exclusive"],
deps = ["//:ray_lib"],
)
+37
View File
@@ -0,0 +1,37 @@
import unittest
from numba import njit
import numpy as np
import ray
@njit(fastmath=True)
def centroid(x, y):
return ((x / x.sum()) * y).sum()
# Define a wrapper to call centroid function
@ray.remote
def centroid_wrapper(x, y):
return centroid(x, y)
class NumbaTest(unittest.TestCase):
def setUp(self):
ray.init(num_cpus=1)
def tearDown(self):
ray.shutdown()
def test_numba_njit(self):
x = np.random.random(10)
y = np.random.random(1)
result = ray.get(centroid_wrapper.remote(x, y))
assert result == centroid(x, y)
if __name__ == "__main__":
import pytest
import sys
sys.exit(pytest.main(["-v", __file__]))