Improve local_mode (#5060)

This commit is contained in:
Edward Oakes
2019-07-08 02:10:50 +02:00
committed by Robert Nishihara
parent 932d6b2517
commit 8f53364097
7 changed files with 280 additions and 85 deletions
+78 -14
View File
@@ -1586,22 +1586,21 @@ def test_local_mode(shutdown_only):
return np.ones([3, 4, 5])
xref = f.remote()
# Remote functions should return by value.
assert np.alltrue(xref == np.ones([3, 4, 5]))
# Check that ray.get is the identity.
assert np.alltrue(xref == ray.get(xref))
# Remote functions should return ObjectIDs.
assert isinstance(xref, ray.ObjectID)
assert np.alltrue(ray.get(xref) == np.ones([3, 4, 5]))
y = np.random.normal(size=[11, 12])
# Check that ray.put is the identity.
assert np.alltrue(y == ray.put(y))
# Check that ray.get(ray.put) is the identity.
assert np.alltrue(y == ray.get(ray.put(y)))
# Make sure objects are immutable, this example is why we need to copy
# arguments before passing them into remote functions in python mode
aref = local_mode_f.remote()
assert np.alltrue(aref == np.array([0, 0]))
bref = local_mode_g.remote(aref)
assert np.alltrue(ray.get(aref) == np.array([0, 0]))
bref = local_mode_g.remote(ray.get(aref))
# Make sure local_mode_g does not mutate aref.
assert np.alltrue(aref == np.array([0, 0]))
assert np.alltrue(bref == np.array([1, 0]))
assert np.alltrue(ray.get(aref) == np.array([0, 0]))
assert np.alltrue(ray.get(bref) == np.array([1, 0]))
# wait should return the first num_returns values passed in as the
# first list and the remaining values as the second list
@@ -1612,6 +1611,25 @@ def test_local_mode(shutdown_only):
assert ready == object_ids[:num_returns]
assert remaining == object_ids[num_returns:]
# Check that ray.put() and ray.internal.free() work in local mode.
v1 = np.ones(10)
v2 = np.zeros(10)
k1 = ray.put(v1)
assert np.alltrue(v1 == ray.get(k1))
k2 = ray.put(v2)
assert np.alltrue(v2 == ray.get(k2))
ray.internal.free([k1, k2])
with pytest.raises(Exception):
ray.get(k1)
with pytest.raises(Exception):
ray.get(k2)
# Should fail silently.
ray.internal.free([k1, k2])
# Test actors in LOCAL_MODE.
@ray.remote
@@ -1629,9 +1647,14 @@ def test_local_mode(shutdown_only):
array[0] = -1
self.array = array
@ray.method(num_return_vals=3)
def returns_multiple(self):
return 1, 2, 3
test_actor = LocalModeTestClass.remote(np.arange(10))
# Remote actor functions should return by value
assert np.alltrue(test_actor.get_array.remote() == np.arange(10))
obj = test_actor.get_array.remote()
assert isinstance(obj, ray.ObjectID)
assert np.alltrue(ray.get(obj) == np.arange(10))
test_array = np.arange(10)
# Remote actor functions should not mutate arguments
@@ -1639,9 +1662,9 @@ def test_local_mode(shutdown_only):
assert np.alltrue(test_array == np.arange(10))
# Remote actor functions should keep state
test_array[0] = -1
assert np.alltrue(test_array == test_actor.get_array.remote())
assert np.alltrue(test_array == ray.get(test_actor.get_array.remote()))
# Check that actor handles work in Python mode.
# Check that actor handles work in local mode.
@ray.remote
def use_actor_handle(handle):
@@ -1651,6 +1674,47 @@ def test_local_mode(shutdown_only):
ray.get(use_actor_handle.remote(test_actor))
# Check that exceptions are deferred until ray.get().
exception_str = "test_basic remote task exception"
@ray.remote
def throws():
raise Exception(exception_str)
obj = throws.remote()
with pytest.raises(Exception, match=exception_str):
ray.get(obj)
# Check that multiple return values are handled properly.
@ray.remote(num_return_vals=3)
def returns_multiple():
return 1, 2, 3
obj1, obj2, obj3 = returns_multiple.remote()
assert ray.get(obj1) == 1
assert ray.get(obj2) == 2
assert ray.get(obj3) == 3
assert ray.get([obj1, obj2, obj3]) == [1, 2, 3]
obj1, obj2, obj3 = test_actor.returns_multiple.remote()
assert ray.get(obj1) == 1
assert ray.get(obj2) == 2
assert ray.get(obj3) == 3
assert ray.get([obj1, obj2, obj3]) == [1, 2, 3]
@ray.remote(num_return_vals=2)
def returns_multiple_throws():
raise Exception(exception_str)
obj1, obj2 = returns_multiple_throws.remote()
with pytest.raises(Exception, match=exception_str):
ray.get(obj)
ray.get(obj1)
with pytest.raises(Exception, match=exception_str):
ray.get(obj2)
def test_resource_constraints(shutdown_only):
num_workers = 20