mirror of
https://github.com/wassname/ray.git
synced 2026-06-27 19:00:36 +08:00
clean up imports (#230)
This commit is contained in:
committed by
Philipp Moritz
parent
191909dd93
commit
5dd411546d
@@ -1,6 +1,5 @@
|
||||
import numpy as np
|
||||
import ray
|
||||
import ray.services as services
|
||||
import os
|
||||
|
||||
import functions
|
||||
@@ -11,7 +10,7 @@ epochs = 100
|
||||
|
||||
worker_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
worker_path = os.path.join(worker_dir, "worker.py")
|
||||
services.start_ray_local(num_workers=num_workers, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=num_workers, worker_path=worker_path)
|
||||
|
||||
best_params = None
|
||||
best_accuracy = 0
|
||||
|
||||
@@ -38,12 +38,12 @@ def train_cnn(params, epochs):
|
||||
sess.run(train_step, feed_dict={x: batch[0], y: batch[1], keep_prob: keep})
|
||||
if i % 100 == 0: # checks if accuracy is low enough to stop early every set number of epochs
|
||||
train_ac = accuracy.eval(feed_dict={x: batch[0], y: batch[1], keep_prob: 1.0})
|
||||
if train_ac < 0.25: # Accuracy threshold is on a application to application basis.
|
||||
if train_ac < 0.25: # Accuracy threshold is on a application to application basis.
|
||||
totalacc = accuracy.eval(feed_dict={x: mnist.validation.images, y: mnist.validation.labels, keep_prob: 1.0})
|
||||
return totalacc
|
||||
totalacc = accuracy.eval(feed_dict={x: mnist.validation.images, y: mnist.validation.labels, keep_prob: 1.0})
|
||||
return totalacc
|
||||
|
||||
return totalacc.astype("float64")
|
||||
|
||||
def cnn_setup(x, y, keep_prob, lr, stddev):
|
||||
first_hidden = 32
|
||||
second_hidden = 64
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import argparse
|
||||
import ray
|
||||
import ray.worker as worker
|
||||
|
||||
import functions
|
||||
|
||||
@@ -13,4 +12,4 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.connect(args.scheduler_address, args.objstore_address, args.worker_address)
|
||||
ray.register_module(functions)
|
||||
worker.main_loop()
|
||||
ray.worker.main_loop()
|
||||
|
||||
@@ -3,20 +3,19 @@ import boto3
|
||||
import os
|
||||
import numpy as np
|
||||
import ray
|
||||
import ray.services as services
|
||||
import ray.datasets.imagenet as imagenet
|
||||
|
||||
import functions
|
||||
|
||||
parser = argparse.ArgumentParser(description="Parse information for data loading.")
|
||||
parser.add_argument("--s3-bucket", type=str, help="Name of the bucket that contains the image data.")
|
||||
parser.add_argument("--s3-bucket", type=str, required=True, help="Name of the bucket that contains the image data.")
|
||||
parser.add_argument("--key-prefix", default="ILSVRC2012_img_train/n015", type=str, help="Prefix for files to fetch.")
|
||||
parser.add_argument("--drop-ipython", default=False, type=bool, help="Drop into IPython at the end?")
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "worker.py")
|
||||
services.start_ray_local(num_workers=5, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=5, worker_path=worker_path)
|
||||
|
||||
s3 = boto3.resource("s3")
|
||||
imagenet_bucket = s3.Bucket(args.s3_bucket)
|
||||
|
||||
@@ -5,8 +5,6 @@ import numpy as np
|
||||
import ray.datasets.imagenet
|
||||
|
||||
import ray
|
||||
import ray.services as services
|
||||
import ray.worker as worker
|
||||
import ray.array.remote as ra
|
||||
import ray.array.distributed as da
|
||||
|
||||
@@ -19,7 +17,7 @@ parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, hel
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
|
||||
ray.worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
|
||||
|
||||
ray.register_module(ray.datasets.imagenet)
|
||||
ray.register_module(functions)
|
||||
@@ -30,4 +28,4 @@ if __name__ == "__main__":
|
||||
ray.register_module(da.random)
|
||||
ray.register_module(da.linalg)
|
||||
|
||||
worker.main_loop()
|
||||
ray.worker.main_loop()
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import argparse
|
||||
|
||||
import ray
|
||||
import ray.worker as worker
|
||||
|
||||
import ray.array.remote as ra
|
||||
import ray.array.distributed as da
|
||||
@@ -15,7 +14,7 @@ parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, hel
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
|
||||
ray.worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
|
||||
|
||||
ray.register_module(functions)
|
||||
|
||||
@@ -26,4 +25,4 @@ if __name__ == "__main__":
|
||||
ray.register_module(da.random)
|
||||
ray.register_module(da.linalg)
|
||||
|
||||
worker.main_loop()
|
||||
ray.worker.main_loop()
|
||||
|
||||
@@ -5,14 +5,13 @@ import numpy as np
|
||||
import cPickle as pickle
|
||||
import gym
|
||||
import ray
|
||||
import ray.services as services
|
||||
import os
|
||||
|
||||
import functions
|
||||
|
||||
worker_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
worker_path = os.path.join(worker_dir, "worker.py")
|
||||
services.start_ray_local(num_workers=10, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=10, worker_path=worker_path)
|
||||
|
||||
# hyperparameters
|
||||
H = 200 # number of hidden layer neurons
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import argparse
|
||||
import ray
|
||||
import ray.worker as worker
|
||||
import gym
|
||||
|
||||
import functions
|
||||
@@ -14,4 +13,4 @@ if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
ray.connect(args.scheduler_address, args.objstore_address, args.worker_address)
|
||||
ray.register_module(functions)
|
||||
worker.main_loop()
|
||||
ray.worker.main_loop()
|
||||
|
||||
@@ -7,8 +7,6 @@ import ray.array.distributed as da
|
||||
import example_functions
|
||||
|
||||
import ray
|
||||
import ray.services as services
|
||||
import ray.worker as worker
|
||||
|
||||
parser = argparse.ArgumentParser(description="Parse addresses for the worker to connect to.")
|
||||
parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address")
|
||||
@@ -17,7 +15,7 @@ parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, hel
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
|
||||
ray.worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
|
||||
|
||||
ray.register_module(ra)
|
||||
ray.register_module(ra.random)
|
||||
@@ -27,4 +25,4 @@ if __name__ == "__main__":
|
||||
ray.register_module(da.linalg)
|
||||
ray.register_module(example_functions)
|
||||
|
||||
worker.main_loop()
|
||||
ray.worker.main_loop()
|
||||
|
||||
+4
-6
@@ -3,8 +3,6 @@ import argparse
|
||||
import numpy as np
|
||||
|
||||
import ray
|
||||
import ray.services as services
|
||||
import ray.worker as worker
|
||||
|
||||
import ray.array.remote as ra
|
||||
import ray.array.distributed as da
|
||||
@@ -26,11 +24,11 @@ if __name__ == "__main__":
|
||||
if args.attach:
|
||||
assert args.worker_path is None, "when attaching, no new worker can be started"
|
||||
assert args.num_workers is None, "when attaching, no new worker can be started"
|
||||
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address, is_driver=True, mode=ray.SHELL_MODE)
|
||||
ray.worker.connect(args.scheduler_address, args.objstore_address, args.worker_address, is_driver=True, mode=ray.SHELL_MODE)
|
||||
else:
|
||||
services.start_ray_local(num_workers=args.num_workers if not args.num_workers is None else DEFAULT_NUM_WORKERS,
|
||||
worker_path=args.worker_path if not args.worker_path is None else DEFAULT_WORKER_PATH,
|
||||
driver_mode=ray.SHELL_MODE)
|
||||
ray.services.start_ray_local(num_workers=args.num_workers if not args.num_workers is None else DEFAULT_NUM_WORKERS,
|
||||
worker_path=args.worker_path if not args.worker_path is None else DEFAULT_WORKER_PATH,
|
||||
driver_mode=ray.SHELL_MODE)
|
||||
|
||||
import IPython
|
||||
IPython.embed()
|
||||
|
||||
+10
-13
@@ -1,8 +1,5 @@
|
||||
import unittest
|
||||
import ray
|
||||
import ray.serialization as serialization
|
||||
import ray.services as services
|
||||
import ray.worker as worker
|
||||
import numpy as np
|
||||
import time
|
||||
import subprocess32 as subprocess
|
||||
@@ -15,7 +12,7 @@ class RemoteArrayTest(unittest.TestCase):
|
||||
|
||||
def testMethods(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=1, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=1, worker_path=worker_path)
|
||||
|
||||
# test eye
|
||||
ref = ra.eye(3)
|
||||
@@ -42,25 +39,25 @@ class RemoteArrayTest(unittest.TestCase):
|
||||
val_r = ray.get(ref_r)
|
||||
self.assertTrue(np.allclose(np.dot(val_q, val_r), val_a))
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
class DistributedArrayTest(unittest.TestCase):
|
||||
|
||||
def testSerialization(self):
|
||||
services.start_ray_local()
|
||||
ray.services.start_ray_local()
|
||||
|
||||
x = da.DistArray()
|
||||
x.construct([2, 3, 4], np.array([[[ray.put(0)]]]))
|
||||
capsule, _ = serialization.serialize(ray.worker.global_worker.handle, x)
|
||||
y = serialization.deserialize(ray.worker.global_worker.handle, capsule)
|
||||
capsule, _ = ray.serialization.serialize(ray.worker.global_worker.handle, x)
|
||||
y = ray.serialization.deserialize(ray.worker.global_worker.handle, capsule)
|
||||
self.assertEqual(x.shape, y.shape)
|
||||
self.assertEqual(x.objrefs[0, 0, 0].val, y.objrefs[0, 0, 0].val)
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
def testAssemble(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=1, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=1, worker_path=worker_path)
|
||||
|
||||
a = ra.ones([da.BLOCK_SIZE, da.BLOCK_SIZE])
|
||||
b = ra.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])
|
||||
@@ -68,11 +65,11 @@ class DistributedArrayTest(unittest.TestCase):
|
||||
x.construct([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]]))
|
||||
self.assertTrue(np.alltrue(x.assemble() == np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])])))
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
def testMethods(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_services_local(num_objstores=2, num_workers_per_objstore=5, worker_path=worker_path)
|
||||
ray.services.start_services_local(num_objstores=2, num_workers_per_objstore=5, worker_path=worker_path)
|
||||
|
||||
x = da.zeros([9, 25, 51], "float")
|
||||
self.assertTrue(np.alltrue(ray.get(da.assemble(x)) == np.zeros([9, 25, 51])))
|
||||
@@ -206,7 +203,7 @@ class DistributedArrayTest(unittest.TestCase):
|
||||
d2 = np.random.randint(1, 35)
|
||||
test_dist_qr(d1, d2)
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -2,21 +2,20 @@ import os
|
||||
import tarfile
|
||||
import unittest
|
||||
import ray
|
||||
import ray.services as services
|
||||
import ray.datasets.imagenet as imagenet
|
||||
|
||||
class ImageNetTest(unittest.TestCase):
|
||||
|
||||
def testImageNetLoading(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=5, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=5, worker_path=worker_path)
|
||||
|
||||
chunk_name = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../data/mini.tar")
|
||||
tar = tarfile.open(chunk_name, mode= "r")
|
||||
chunk = imagenet.load_chunk(tar, size=(256, 256))
|
||||
self.assertEqual(chunk.shape, (2, 256, 256, 3))
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -3,11 +3,9 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import ray
|
||||
import ray.worker
|
||||
import ray.services as services
|
||||
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=1, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=1, worker_path=worker_path)
|
||||
|
||||
d = {"w": np.zeros(1000000)}
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import unittest
|
||||
import ray
|
||||
import ray.services as services
|
||||
import time
|
||||
import os
|
||||
import numpy as np
|
||||
@@ -11,7 +10,7 @@ class MicroBenchmarkTest(unittest.TestCase):
|
||||
|
||||
def testTiming(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=3, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=3, worker_path=worker_path)
|
||||
|
||||
# measure the time required to submit a remote task to the scheduler
|
||||
elapsed_times = []
|
||||
@@ -78,7 +77,7 @@ class MicroBenchmarkTest(unittest.TestCase):
|
||||
print " worst: {}".format(elapsed_times[999])
|
||||
self.assertTrue(average_elapsed_time < 0.002) # should take 0.00087
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
+45
-48
@@ -1,8 +1,5 @@
|
||||
import unittest
|
||||
import ray
|
||||
import ray.serialization as serialization
|
||||
import ray.services as services
|
||||
import ray.worker as worker
|
||||
import numpy as np
|
||||
import time
|
||||
import subprocess32 as subprocess
|
||||
@@ -34,35 +31,35 @@ class UserDefinedType(object):
|
||||
class SerializationTest(unittest.TestCase):
|
||||
|
||||
def roundTripTest(self, data):
|
||||
serialized, _ = serialization.serialize(ray.worker.global_worker.handle, data)
|
||||
result = serialization.deserialize(ray.worker.global_worker.handle, serialized)
|
||||
serialized, _ = ray.serialization.serialize(ray.worker.global_worker.handle, data)
|
||||
result = ray.serialization.deserialize(ray.worker.global_worker.handle, serialized)
|
||||
self.assertEqual(data, result)
|
||||
|
||||
def numpyTypeTest(self, typ):
|
||||
a = np.random.randint(0, 10, size=(100, 100)).astype(typ)
|
||||
b, _ = serialization.serialize(ray.worker.global_worker.handle, a)
|
||||
c = serialization.deserialize(ray.worker.global_worker.handle, b)
|
||||
b, _ = ray.serialization.serialize(ray.worker.global_worker.handle, a)
|
||||
c = ray.serialization.deserialize(ray.worker.global_worker.handle, b)
|
||||
self.assertTrue((a == c).all())
|
||||
|
||||
a = np.array(0).astype(typ)
|
||||
b, _ = serialization.serialize(ray.worker.global_worker.handle, a)
|
||||
c = serialization.deserialize(ray.worker.global_worker.handle, b)
|
||||
b, _ = ray.serialization.serialize(ray.worker.global_worker.handle, a)
|
||||
c = ray.serialization.deserialize(ray.worker.global_worker.handle, b)
|
||||
self.assertTrue((a == c).all())
|
||||
|
||||
a = np.empty((0,)).astype(typ)
|
||||
b, _ = serialization.serialize(ray.worker.global_worker.handle, a)
|
||||
c = serialization.deserialize(ray.worker.global_worker.handle, b)
|
||||
b, _ = ray.serialization.serialize(ray.worker.global_worker.handle, a)
|
||||
c = ray.serialization.deserialize(ray.worker.global_worker.handle, b)
|
||||
self.assertTrue(a.dtype == c.dtype)
|
||||
|
||||
def testSerialize(self):
|
||||
services.start_ray_local()
|
||||
ray.services.start_ray_local()
|
||||
|
||||
for val in RAY_TEST_OBJECTS:
|
||||
self.roundTripTest(val)
|
||||
|
||||
a = np.zeros((100, 100))
|
||||
res, _ = serialization.serialize(ray.worker.global_worker.handle, a)
|
||||
b = serialization.deserialize(ray.worker.global_worker.handle, res)
|
||||
res, _ = ray.serialization.serialize(ray.worker.global_worker.handle, a)
|
||||
b = ray.serialization.deserialize(ray.worker.global_worker.handle, res)
|
||||
self.assertTrue((a == b).all())
|
||||
|
||||
self.numpyTypeTest("int8")
|
||||
@@ -80,8 +77,8 @@ class SerializationTest(unittest.TestCase):
|
||||
ref3 = ray.put(0)
|
||||
|
||||
a = np.array([[ref0, ref1], [ref2, ref3]])
|
||||
capsule, _ = serialization.serialize(ray.worker.global_worker.handle, a)
|
||||
result = serialization.deserialize(ray.worker.global_worker.handle, capsule)
|
||||
capsule, _ = ray.serialization.serialize(ray.worker.global_worker.handle, a)
|
||||
result = ray.serialization.deserialize(ray.worker.global_worker.handle, capsule)
|
||||
self.assertTrue((a == result).all())
|
||||
|
||||
self.roundTripTest(ref0)
|
||||
@@ -89,13 +86,13 @@ class SerializationTest(unittest.TestCase):
|
||||
self.roundTripTest({"0": ref0, "1": ref1, "2": ref2, "3": ref3})
|
||||
self.roundTripTest((ref0, 1))
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
class ObjStoreTest(unittest.TestCase):
|
||||
|
||||
# Test setting up object stores, transfering data between them and retrieving data to a client
|
||||
def testObjStore(self):
|
||||
[w1, w2] = services.start_services_local(return_drivers=True, num_objstores=2, num_workers_per_objstore=0)
|
||||
[w1, w2] = ray.services.start_services_local(return_drivers=True, num_objstores=2, num_workers_per_objstore=0)
|
||||
|
||||
# putting and getting an object shouldn't change it
|
||||
for data in ["h", "h" * 10000, 0, 0.0]:
|
||||
@@ -105,14 +102,14 @@ class ObjStoreTest(unittest.TestCase):
|
||||
|
||||
# putting an object, shipping it to another worker, and getting it shouldn't change it
|
||||
for data in ["h", "h" * 10000, 0, 0.0, [1, 2, 3, "a", (1, 2)], ("a", ("b", 3))]:
|
||||
objref = worker.put(data, w1)
|
||||
result = worker.get(objref, w2)
|
||||
objref = ray.put(data, w1)
|
||||
result = ray.get(objref, w2)
|
||||
self.assertEqual(result, data)
|
||||
|
||||
# putting an array, shipping it to another worker, and getting it shouldn't change it
|
||||
for data in [np.zeros([10, 20]), np.random.normal(size=[45, 25])]:
|
||||
objref = worker.put(data, w1)
|
||||
result = worker.get(objref, w2)
|
||||
objref = ray.put(data, w1)
|
||||
result = ray.get(objref, w2)
|
||||
self.assertTrue(np.alltrue(result == data))
|
||||
|
||||
"""
|
||||
@@ -127,24 +124,24 @@ class ObjStoreTest(unittest.TestCase):
|
||||
|
||||
# shipping a numpy array inside something else should be fine
|
||||
data = ("a", np.random.normal(size=[10, 10]))
|
||||
objref = worker.put(data, w1)
|
||||
result = worker.get(objref, w2)
|
||||
objref = ray.put(data, w1)
|
||||
result = ray.get(objref, w2)
|
||||
self.assertTrue(data[0] == result[0])
|
||||
self.assertTrue(np.alltrue(data[1] == result[1]))
|
||||
|
||||
# shipping a numpy array inside something else should be fine
|
||||
data = ["a", np.random.normal(size=[10, 10])]
|
||||
objref = worker.put(data, w1)
|
||||
result = worker.get(objref, w2)
|
||||
objref = ray.put(data, w1)
|
||||
result = ray.get(objref, w2)
|
||||
self.assertTrue(data[0] == result[0])
|
||||
self.assertTrue(np.alltrue(data[1] == result[1]))
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
class WorkerTest(unittest.TestCase):
|
||||
|
||||
def testPutGet(self):
|
||||
services.start_ray_local()
|
||||
ray.services.start_ray_local()
|
||||
|
||||
for i in range(100):
|
||||
value_before = i * 10 ** 6
|
||||
@@ -170,13 +167,13 @@ class WorkerTest(unittest.TestCase):
|
||||
value_after = ray.get(objref)
|
||||
self.assertEqual(value_before, value_after)
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
class APITest(unittest.TestCase):
|
||||
|
||||
def testObjRefAliasing(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=3, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=3, worker_path=worker_path)
|
||||
|
||||
ref = test_functions.test_alias_f()
|
||||
self.assertTrue(np.alltrue(ray.get(ref) == np.ones([3, 4, 5])))
|
||||
@@ -185,11 +182,11 @@ class APITest(unittest.TestCase):
|
||||
ref = test_functions.test_alias_h()
|
||||
self.assertTrue(np.alltrue(ray.get(ref) == np.ones([3, 4, 5])))
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
def testKeywordArgs(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=1, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=1, worker_path=worker_path)
|
||||
|
||||
x = test_functions.keyword_fct1(1)
|
||||
self.assertEqual(ray.get(x), "1 hello")
|
||||
@@ -222,11 +219,11 @@ class APITest(unittest.TestCase):
|
||||
x = test_functions.keyword_fct3(0, 1)
|
||||
self.assertEqual(ray.get(x), "0 1 hello world")
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
def testVariableNumberOfArgs(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=1, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=1, worker_path=worker_path)
|
||||
|
||||
x = test_functions.varargs_fct1(0, 1, 2)
|
||||
self.assertEqual(ray.get(x), "0 1 2")
|
||||
@@ -236,11 +233,11 @@ class APITest(unittest.TestCase):
|
||||
self.assertTrue(test_functions.kwargs_exception_thrown)
|
||||
self.assertTrue(test_functions.varargs_and_kwargs_exception_thrown)
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
def testNoArgs(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.WORKER_MODE)
|
||||
ray.services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.WORKER_MODE)
|
||||
|
||||
test_functions.no_op()
|
||||
time.sleep(0.2)
|
||||
@@ -257,11 +254,11 @@ class APITest(unittest.TestCase):
|
||||
self.assertEqual(task_info["num_succeeded"], 1)
|
||||
self.assertTrue("The @remote decorator for function test_functions.no_op_fail has 0 return values, but test_functions.no_op_fail returned more than 0 values." in task_info["failed_tasks"][0].get("error_message"))
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
def testTypeChecking(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.WORKER_MODE)
|
||||
ray.services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.WORKER_MODE)
|
||||
|
||||
# Make sure that these functions throw exceptions because there return
|
||||
# values do not type check.
|
||||
@@ -273,12 +270,12 @@ class APITest(unittest.TestCase):
|
||||
self.assertEqual(len(task_info["running_tasks"]), 0)
|
||||
self.assertEqual(task_info["num_succeeded"], 0)
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
class TaskStatusTest(unittest.TestCase):
|
||||
def testFailedTask(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=3, worker_path=worker_path, driver_mode=ray.WORKER_MODE)
|
||||
ray.services.start_ray_local(num_workers=3, worker_path=worker_path, driver_mode=ray.WORKER_MODE)
|
||||
test_functions.test_alias_f()
|
||||
test_functions.throw_exception_fct1()
|
||||
test_functions.throw_exception_fct1()
|
||||
@@ -324,7 +321,7 @@ class ReferenceCountingTest(unittest.TestCase):
|
||||
|
||||
def testDeallocation(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=3, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=3, worker_path=worker_path)
|
||||
|
||||
x = test_functions.test_alias_f()
|
||||
ray.get(x)
|
||||
@@ -370,11 +367,11 @@ class ReferenceCountingTest(unittest.TestCase):
|
||||
time.sleep(0.1)
|
||||
self.assertTrue(ray.scheduler_info()["reference_counts"][objref_val:(objref_val + 3)] == [-1, -1, -1])
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
def testGet(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=3, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=3, worker_path=worker_path)
|
||||
|
||||
for val in RAY_TEST_OBJECTS + [np.zeros((2, 2)), UserDefinedType()]:
|
||||
objref_val = check_get_deallocated(val)
|
||||
@@ -395,12 +392,12 @@ class ReferenceCountingTest(unittest.TestCase):
|
||||
self.assertTrue(np.alltrue(result == data))
|
||||
"""
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
@unittest.expectedFailure
|
||||
def testGetFailing(self):
|
||||
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
|
||||
services.start_ray_local(num_workers=3, worker_path=worker_path)
|
||||
ray.services.start_ray_local(num_workers=3, worker_path=worker_path)
|
||||
|
||||
# This is failing, because for bool and None, we cannot track python
|
||||
# refcounts and therefore cannot keep the refcount up
|
||||
@@ -412,12 +409,12 @@ class ReferenceCountingTest(unittest.TestCase):
|
||||
x, objref_val = check_get_not_deallocated(val)
|
||||
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], 1)
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
class PythonModeTest(unittest.TestCase):
|
||||
|
||||
def testObjRefAliasing(self):
|
||||
services.start_ray_local(driver_mode=ray.PYTHON_MODE)
|
||||
ray.services.start_ray_local(driver_mode=ray.PYTHON_MODE)
|
||||
|
||||
xref = test_functions.test_alias_h()
|
||||
self.assertTrue(np.alltrue(xref == np.ones([3, 4, 5]))) # remote functions should return by value
|
||||
@@ -433,7 +430,7 @@ class PythonModeTest(unittest.TestCase):
|
||||
self.assertTrue(np.alltrue(aref == np.array([0, 0]))) # python_mode_g should not mutate aref
|
||||
self.assertTrue(np.alltrue(bref == np.array([1, 0])))
|
||||
|
||||
services.cleanup()
|
||||
ray.services.cleanup()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
+2
-4
@@ -8,8 +8,6 @@ import ray.array.distributed as da
|
||||
import ray.datasets.imagenet
|
||||
|
||||
import ray
|
||||
import ray.services as services
|
||||
import ray.worker as worker
|
||||
|
||||
parser = argparse.ArgumentParser(description='Parse addresses for the worker to connect to.')
|
||||
parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address")
|
||||
@@ -18,7 +16,7 @@ parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, hel
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parser.parse_args()
|
||||
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
|
||||
ray.worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
|
||||
|
||||
ray.register_module(test_functions)
|
||||
ray.register_module(ra)
|
||||
@@ -29,4 +27,4 @@ if __name__ == "__main__":
|
||||
ray.register_module(da.linalg)
|
||||
ray.register_module(sys.modules[__name__])
|
||||
|
||||
worker.main_loop()
|
||||
ray.worker.main_loop()
|
||||
|
||||
Reference in New Issue
Block a user