mirror of
https://github.com/wassname/ray.git
synced 2026-06-28 00:29:38 +08:00
320 lines
9.3 KiB
Python
320 lines
9.3 KiB
Python
import pytest
|
|
import time
|
|
import sys
|
|
import logging
|
|
|
|
from ray.experimental.client.common import ClientObjectRef
|
|
from ray.experimental.client.ray_client_helpers import ray_start_client_server
|
|
|
|
|
|
def test_real_ray_fallback(ray_start_regular_shared):
|
|
with ray_start_client_server() as ray:
|
|
|
|
@ray.remote
|
|
def get_nodes_real():
|
|
import ray as real_ray
|
|
return real_ray.nodes()
|
|
|
|
nodes = ray.get(get_nodes_real.remote())
|
|
assert len(nodes) == 1, nodes
|
|
|
|
@ray.remote
|
|
def get_nodes():
|
|
# Can access the full Ray API in remote methods.
|
|
return ray.nodes()
|
|
|
|
nodes = ray.get(get_nodes.remote())
|
|
assert len(nodes) == 1, nodes
|
|
|
|
|
|
def test_nested_function(ray_start_regular_shared):
|
|
with ray_start_client_server() as ray:
|
|
|
|
@ray.remote
|
|
def g():
|
|
@ray.remote
|
|
def f():
|
|
return "OK"
|
|
|
|
return ray.get(f.remote())
|
|
|
|
assert ray.get(g.remote()) == "OK"
|
|
|
|
|
|
def test_put_get(ray_start_regular_shared):
|
|
with ray_start_client_server() as ray:
|
|
objectref = ray.put("hello world")
|
|
print(objectref)
|
|
|
|
retval = ray.get(objectref)
|
|
assert retval == "hello world"
|
|
|
|
|
|
def test_wait(ray_start_regular_shared):
|
|
with ray_start_client_server() as ray:
|
|
objectref = ray.put("hello world")
|
|
ready, remaining = ray.wait([objectref])
|
|
assert remaining == []
|
|
retval = ray.get(ready[0])
|
|
assert retval == "hello world"
|
|
|
|
objectref2 = ray.put(5)
|
|
ready, remaining = ray.wait([objectref, objectref2])
|
|
assert (ready, remaining) == ([objectref], [objectref2]) or \
|
|
(ready, remaining) == ([objectref2], [objectref])
|
|
ready_retval = ray.get(ready[0])
|
|
remaining_retval = ray.get(remaining[0])
|
|
assert (ready_retval, remaining_retval) == ("hello world", 5) \
|
|
or (ready_retval, remaining_retval) == (5, "hello world")
|
|
|
|
with pytest.raises(Exception):
|
|
# Reference not in the object store.
|
|
ray.wait([ClientObjectRef("blabla")])
|
|
with pytest.raises(TypeError):
|
|
ray.wait("blabla")
|
|
with pytest.raises(TypeError):
|
|
ray.wait(ClientObjectRef("blabla"))
|
|
with pytest.raises(TypeError):
|
|
ray.wait(["blabla"])
|
|
|
|
|
|
def test_remote_functions(ray_start_regular_shared):
|
|
with ray_start_client_server() as ray:
|
|
|
|
@ray.remote
|
|
def plus2(x):
|
|
return x + 2
|
|
|
|
@ray.remote
|
|
def fact(x):
|
|
print(x, type(fact))
|
|
if x <= 0:
|
|
return 1
|
|
# This hits the "nested tasks" issue
|
|
# https://github.com/ray-project/ray/issues/3644
|
|
# So we're on the right track!
|
|
return ray.get(fact.remote(x - 1)) * x
|
|
|
|
ref2 = plus2.remote(234)
|
|
# `236`
|
|
assert ray.get(ref2) == 236
|
|
|
|
ref3 = fact.remote(20)
|
|
# `2432902008176640000`
|
|
assert ray.get(ref3) == 2_432_902_008_176_640_000
|
|
|
|
# Reuse the cached ClientRemoteFunc object
|
|
ref4 = fact.remote(5)
|
|
assert ray.get(ref4) == 120
|
|
|
|
# Test ray.wait()
|
|
ref5 = fact.remote(10)
|
|
# should return ref2, ref3, ref4
|
|
res = ray.wait([ref5, ref2, ref3, ref4], num_returns=3)
|
|
assert [ref2, ref3, ref4] == res[0]
|
|
assert [ref5] == res[1]
|
|
assert ray.get(res[0]) == [236, 2_432_902_008_176_640_000, 120]
|
|
# should return ref2, ref3, ref4, ref5
|
|
res = ray.wait([ref2, ref3, ref4, ref5], num_returns=4)
|
|
assert [ref2, ref3, ref4, ref5] == res[0]
|
|
assert [] == res[1]
|
|
all_vals = ray.get(res[0])
|
|
assert all_vals == [236, 2_432_902_008_176_640_000, 120, 3628800]
|
|
|
|
|
|
def test_function_calling_function(ray_start_regular_shared):
|
|
with ray_start_client_server() as ray:
|
|
|
|
@ray.remote
|
|
def g():
|
|
return "OK"
|
|
|
|
@ray.remote
|
|
def f():
|
|
print(f, g)
|
|
return ray.get(g.remote())
|
|
|
|
print(f, type(f))
|
|
assert ray.get(f.remote()) == "OK"
|
|
|
|
|
|
def test_basic_actor(ray_start_regular_shared):
|
|
with ray_start_client_server() as ray:
|
|
|
|
@ray.remote
|
|
class HelloActor:
|
|
def __init__(self):
|
|
self.count = 0
|
|
|
|
def say_hello(self, whom):
|
|
self.count += 1
|
|
return ("Hello " + whom, self.count)
|
|
|
|
actor = HelloActor.remote()
|
|
s, count = ray.get(actor.say_hello.remote("you"))
|
|
assert s == "Hello you"
|
|
assert count == 1
|
|
s, count = ray.get(actor.say_hello.remote("world"))
|
|
assert s == "Hello world"
|
|
assert count == 2
|
|
|
|
|
|
def test_pass_handles(ray_start_regular_shared):
|
|
"""Test that passing client handles to actors and functions to remote actors
|
|
in functions (on the server or raylet side) works transparently to the
|
|
caller.
|
|
"""
|
|
with ray_start_client_server() as ray:
|
|
|
|
@ray.remote
|
|
class ExecActor:
|
|
def exec(self, f, x):
|
|
return ray.get(f.remote(x))
|
|
|
|
def exec_exec(self, actor, f, x):
|
|
return ray.get(actor.exec.remote(f, x))
|
|
|
|
@ray.remote
|
|
def fact(x):
|
|
out = 1
|
|
while x > 0:
|
|
out = out * x
|
|
x -= 1
|
|
return out
|
|
|
|
@ray.remote
|
|
def func_exec(f, x):
|
|
return ray.get(f.remote(x))
|
|
|
|
@ray.remote
|
|
def func_actor_exec(actor, f, x):
|
|
return ray.get(actor.exec.remote(f, x))
|
|
|
|
@ray.remote
|
|
def sneaky_func_exec(obj, x):
|
|
return ray.get(obj["f"].remote(x))
|
|
|
|
@ray.remote
|
|
def sneaky_actor_exec(obj, x):
|
|
return ray.get(obj["actor"].exec.remote(obj["f"], x))
|
|
|
|
def local_fact(x):
|
|
if x <= 0:
|
|
return 1
|
|
return x * local_fact(x - 1)
|
|
|
|
assert ray.get(fact.remote(7)) == local_fact(7)
|
|
assert ray.get(func_exec.remote(fact, 8)) == local_fact(8)
|
|
test_obj = {}
|
|
test_obj["f"] = fact
|
|
assert ray.get(sneaky_func_exec.remote(test_obj, 5)) == local_fact(5)
|
|
actor_handle = ExecActor.remote()
|
|
assert ray.get(actor_handle.exec.remote(fact, 7)) == local_fact(7)
|
|
assert ray.get(func_actor_exec.remote(actor_handle, fact,
|
|
10)) == local_fact(10)
|
|
second_actor = ExecActor.remote()
|
|
assert ray.get(actor_handle.exec_exec.remote(second_actor, fact,
|
|
9)) == local_fact(9)
|
|
test_actor_obj = {}
|
|
test_actor_obj["actor"] = second_actor
|
|
test_actor_obj["f"] = fact
|
|
assert ray.get(sneaky_actor_exec.remote(test_actor_obj,
|
|
4)) == local_fact(4)
|
|
|
|
|
|
def test_basic_log_stream(ray_start_regular_shared):
|
|
with ray_start_client_server() as ray:
|
|
log_msgs = []
|
|
|
|
def test_log(level, msg):
|
|
log_msgs.append(msg)
|
|
|
|
ray.worker.log_client.log = test_log
|
|
ray.worker.log_client.set_logstream_level(logging.DEBUG)
|
|
# Allow some time to propogate
|
|
time.sleep(1)
|
|
x = ray.put("Foo")
|
|
assert ray.get(x) == "Foo"
|
|
time.sleep(1)
|
|
logs_with_id = [msg for msg in log_msgs if msg.find(x.id.hex()) >= 0]
|
|
assert len(logs_with_id) >= 2
|
|
assert any((msg.find("get") >= 0 for msg in logs_with_id))
|
|
assert any((msg.find("put") >= 0 for msg in logs_with_id))
|
|
|
|
|
|
def test_stdout_log_stream(ray_start_regular_shared):
|
|
with ray_start_client_server() as ray:
|
|
log_msgs = []
|
|
|
|
def test_log(level, msg):
|
|
log_msgs.append(msg)
|
|
|
|
ray.worker.log_client.stdstream = test_log
|
|
|
|
@ray.remote
|
|
def print_on_stderr_and_stdout(s):
|
|
print(s)
|
|
print(s, file=sys.stderr)
|
|
|
|
time.sleep(1)
|
|
print_on_stderr_and_stdout.remote("Hello world")
|
|
time.sleep(1)
|
|
assert len(log_msgs) == 2
|
|
assert all((msg.find("Hello world") for msg in log_msgs))
|
|
|
|
|
|
def test_create_remote_before_start(ray_start_regular_shared):
|
|
"""Creates remote objects (as though in a library) before
|
|
starting the client.
|
|
"""
|
|
from ray.experimental.client import ray
|
|
|
|
@ray.remote
|
|
class Returner:
|
|
def doit(self):
|
|
return "foo"
|
|
|
|
@ray.remote
|
|
def f(x):
|
|
return x + 20
|
|
|
|
# Prints in verbose tests
|
|
print("Created remote functions")
|
|
|
|
with ray_start_client_server() as ray:
|
|
assert ray.get(f.remote(3)) == 23
|
|
a = Returner.remote()
|
|
assert ray.get(a.doit.remote()) == "foo"
|
|
|
|
|
|
def test_basic_named_actor(ray_start_regular_shared):
|
|
"""Test that ray.get_actor() can create and return a detached actor.
|
|
"""
|
|
with ray_start_client_server() as ray:
|
|
|
|
@ray.remote
|
|
class Accumulator:
|
|
def __init__(self):
|
|
self.x = 0
|
|
|
|
def inc(self):
|
|
self.x += 1
|
|
|
|
def get(self):
|
|
return self.x
|
|
|
|
# Create the actor
|
|
actor = Accumulator.options(name="test_acc").remote()
|
|
|
|
actor.inc.remote()
|
|
actor.inc.remote()
|
|
del actor
|
|
|
|
new_actor = ray.get_actor("test_acc")
|
|
new_actor.inc.remote()
|
|
assert ray.get(new_actor.get.remote()) == 3
|
|
|
|
|
|
if __name__ == "__main__":
|
|
sys.exit(pytest.main(["-v", __file__]))
|