clean up imports (#230)

This commit is contained in:
Robert Nishihara
2016-07-08 12:46:47 -07:00
committed by Philipp Moritz
parent 191909dd93
commit 5dd411546d
16 changed files with 81 additions and 105 deletions
+1 -2
View File
@@ -1,6 +1,5 @@
import numpy as np
import ray
import ray.services as services
import os
import functions
@@ -11,7 +10,7 @@ epochs = 100
worker_dir = os.path.dirname(os.path.abspath(__file__))
worker_path = os.path.join(worker_dir, "worker.py")
services.start_ray_local(num_workers=num_workers, worker_path=worker_path)
ray.services.start_ray_local(num_workers=num_workers, worker_path=worker_path)
best_params = None
best_accuracy = 0
+3 -3
View File
@@ -38,12 +38,12 @@ def train_cnn(params, epochs):
sess.run(train_step, feed_dict={x: batch[0], y: batch[1], keep_prob: keep})
if i % 100 == 0: # checks if accuracy is low enough to stop early every set number of epochs
train_ac = accuracy.eval(feed_dict={x: batch[0], y: batch[1], keep_prob: 1.0})
if train_ac < 0.25: # Accuracy threshold is on a application to application basis.
if train_ac < 0.25: # Accuracy threshold is on a application to application basis.
totalacc = accuracy.eval(feed_dict={x: mnist.validation.images, y: mnist.validation.labels, keep_prob: 1.0})
return totalacc
totalacc = accuracy.eval(feed_dict={x: mnist.validation.images, y: mnist.validation.labels, keep_prob: 1.0})
return totalacc
return totalacc.astype("float64")
def cnn_setup(x, y, keep_prob, lr, stddev):
first_hidden = 32
second_hidden = 64
+1 -2
View File
@@ -1,6 +1,5 @@
import argparse
import ray
import ray.worker as worker
import functions
@@ -13,4 +12,4 @@ if __name__ == "__main__":
args = parser.parse_args()
ray.connect(args.scheduler_address, args.objstore_address, args.worker_address)
ray.register_module(functions)
worker.main_loop()
ray.worker.main_loop()
+2 -3
View File
@@ -3,20 +3,19 @@ import boto3
import os
import numpy as np
import ray
import ray.services as services
import ray.datasets.imagenet as imagenet
import functions
parser = argparse.ArgumentParser(description="Parse information for data loading.")
parser.add_argument("--s3-bucket", type=str, help="Name of the bucket that contains the image data.")
parser.add_argument("--s3-bucket", type=str, required=True, help="Name of the bucket that contains the image data.")
parser.add_argument("--key-prefix", default="ILSVRC2012_img_train/n015", type=str, help="Prefix for files to fetch.")
parser.add_argument("--drop-ipython", default=False, type=bool, help="Drop into IPython at the end?")
if __name__ == "__main__":
args = parser.parse_args()
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "worker.py")
services.start_ray_local(num_workers=5, worker_path=worker_path)
ray.services.start_ray_local(num_workers=5, worker_path=worker_path)
s3 = boto3.resource("s3")
imagenet_bucket = s3.Bucket(args.s3_bucket)
+2 -4
View File
@@ -5,8 +5,6 @@ import numpy as np
import ray.datasets.imagenet
import ray
import ray.services as services
import ray.worker as worker
import ray.array.remote as ra
import ray.array.distributed as da
@@ -19,7 +17,7 @@ parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, hel
if __name__ == "__main__":
args = parser.parse_args()
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
ray.worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
ray.register_module(ray.datasets.imagenet)
ray.register_module(functions)
@@ -30,4 +28,4 @@ if __name__ == "__main__":
ray.register_module(da.random)
ray.register_module(da.linalg)
worker.main_loop()
ray.worker.main_loop()
+2 -3
View File
@@ -1,7 +1,6 @@
import argparse
import ray
import ray.worker as worker
import ray.array.remote as ra
import ray.array.distributed as da
@@ -15,7 +14,7 @@ parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, hel
if __name__ == "__main__":
args = parser.parse_args()
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
ray.worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
ray.register_module(functions)
@@ -26,4 +25,4 @@ if __name__ == "__main__":
ray.register_module(da.random)
ray.register_module(da.linalg)
worker.main_loop()
ray.worker.main_loop()
+1 -2
View File
@@ -5,14 +5,13 @@ import numpy as np
import cPickle as pickle
import gym
import ray
import ray.services as services
import os
import functions
worker_dir = os.path.dirname(os.path.abspath(__file__))
worker_path = os.path.join(worker_dir, "worker.py")
services.start_ray_local(num_workers=10, worker_path=worker_path)
ray.services.start_ray_local(num_workers=10, worker_path=worker_path)
# hyperparameters
H = 200 # number of hidden layer neurons
+1 -2
View File
@@ -1,6 +1,5 @@
import argparse
import ray
import ray.worker as worker
import gym
import functions
@@ -14,4 +13,4 @@ if __name__ == "__main__":
args = parser.parse_args()
ray.connect(args.scheduler_address, args.objstore_address, args.worker_address)
ray.register_module(functions)
worker.main_loop()
ray.worker.main_loop()
+2 -4
View File
@@ -7,8 +7,6 @@ import ray.array.distributed as da
import example_functions
import ray
import ray.services as services
import ray.worker as worker
parser = argparse.ArgumentParser(description="Parse addresses for the worker to connect to.")
parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address")
@@ -17,7 +15,7 @@ parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, hel
if __name__ == "__main__":
args = parser.parse_args()
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
ray.worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
ray.register_module(ra)
ray.register_module(ra.random)
@@ -27,4 +25,4 @@ if __name__ == "__main__":
ray.register_module(da.linalg)
ray.register_module(example_functions)
worker.main_loop()
ray.worker.main_loop()
+4 -6
View File
@@ -3,8 +3,6 @@ import argparse
import numpy as np
import ray
import ray.services as services
import ray.worker as worker
import ray.array.remote as ra
import ray.array.distributed as da
@@ -26,11 +24,11 @@ if __name__ == "__main__":
if args.attach:
assert args.worker_path is None, "when attaching, no new worker can be started"
assert args.num_workers is None, "when attaching, no new worker can be started"
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address, is_driver=True, mode=ray.SHELL_MODE)
ray.worker.connect(args.scheduler_address, args.objstore_address, args.worker_address, is_driver=True, mode=ray.SHELL_MODE)
else:
services.start_ray_local(num_workers=args.num_workers if not args.num_workers is None else DEFAULT_NUM_WORKERS,
worker_path=args.worker_path if not args.worker_path is None else DEFAULT_WORKER_PATH,
driver_mode=ray.SHELL_MODE)
ray.services.start_ray_local(num_workers=args.num_workers if not args.num_workers is None else DEFAULT_NUM_WORKERS,
worker_path=args.worker_path if not args.worker_path is None else DEFAULT_WORKER_PATH,
driver_mode=ray.SHELL_MODE)
import IPython
IPython.embed()
+10 -13
View File
@@ -1,8 +1,5 @@
import unittest
import ray
import ray.serialization as serialization
import ray.services as services
import ray.worker as worker
import numpy as np
import time
import subprocess32 as subprocess
@@ -15,7 +12,7 @@ class RemoteArrayTest(unittest.TestCase):
def testMethods(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=1, worker_path=worker_path)
ray.services.start_ray_local(num_workers=1, worker_path=worker_path)
# test eye
ref = ra.eye(3)
@@ -42,25 +39,25 @@ class RemoteArrayTest(unittest.TestCase):
val_r = ray.get(ref_r)
self.assertTrue(np.allclose(np.dot(val_q, val_r), val_a))
services.cleanup()
ray.services.cleanup()
class DistributedArrayTest(unittest.TestCase):
def testSerialization(self):
services.start_ray_local()
ray.services.start_ray_local()
x = da.DistArray()
x.construct([2, 3, 4], np.array([[[ray.put(0)]]]))
capsule, _ = serialization.serialize(ray.worker.global_worker.handle, x)
y = serialization.deserialize(ray.worker.global_worker.handle, capsule)
capsule, _ = ray.serialization.serialize(ray.worker.global_worker.handle, x)
y = ray.serialization.deserialize(ray.worker.global_worker.handle, capsule)
self.assertEqual(x.shape, y.shape)
self.assertEqual(x.objrefs[0, 0, 0].val, y.objrefs[0, 0, 0].val)
services.cleanup()
ray.services.cleanup()
def testAssemble(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=1, worker_path=worker_path)
ray.services.start_ray_local(num_workers=1, worker_path=worker_path)
a = ra.ones([da.BLOCK_SIZE, da.BLOCK_SIZE])
b = ra.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])
@@ -68,11 +65,11 @@ class DistributedArrayTest(unittest.TestCase):
x.construct([2 * da.BLOCK_SIZE, da.BLOCK_SIZE], np.array([[a], [b]]))
self.assertTrue(np.alltrue(x.assemble() == np.vstack([np.ones([da.BLOCK_SIZE, da.BLOCK_SIZE]), np.zeros([da.BLOCK_SIZE, da.BLOCK_SIZE])])))
services.cleanup()
ray.services.cleanup()
def testMethods(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_services_local(num_objstores=2, num_workers_per_objstore=5, worker_path=worker_path)
ray.services.start_services_local(num_objstores=2, num_workers_per_objstore=5, worker_path=worker_path)
x = da.zeros([9, 25, 51], "float")
self.assertTrue(np.alltrue(ray.get(da.assemble(x)) == np.zeros([9, 25, 51])))
@@ -206,7 +203,7 @@ class DistributedArrayTest(unittest.TestCase):
d2 = np.random.randint(1, 35)
test_dist_qr(d1, d2)
services.cleanup()
ray.services.cleanup()
if __name__ == "__main__":
unittest.main()
+2 -3
View File
@@ -2,21 +2,20 @@ import os
import tarfile
import unittest
import ray
import ray.services as services
import ray.datasets.imagenet as imagenet
class ImageNetTest(unittest.TestCase):
def testImageNetLoading(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=5, worker_path=worker_path)
ray.services.start_ray_local(num_workers=5, worker_path=worker_path)
chunk_name = os.path.join(os.path.dirname(os.path.abspath(__file__)), "../data/mini.tar")
tar = tarfile.open(chunk_name, mode= "r")
chunk = imagenet.load_chunk(tar, size=(256, 256))
self.assertEqual(chunk.shape, (2, 256, 256, 3))
services.cleanup()
ray.services.cleanup()
if __name__ == "__main__":
unittest.main()
+1 -3
View File
@@ -3,11 +3,9 @@
import os
import numpy as np
import ray
import ray.worker
import ray.services as services
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=1, worker_path=worker_path)
ray.services.start_ray_local(num_workers=1, worker_path=worker_path)
d = {"w": np.zeros(1000000)}
+2 -3
View File
@@ -1,6 +1,5 @@
import unittest
import ray
import ray.services as services
import time
import os
import numpy as np
@@ -11,7 +10,7 @@ class MicroBenchmarkTest(unittest.TestCase):
def testTiming(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=3, worker_path=worker_path)
ray.services.start_ray_local(num_workers=3, worker_path=worker_path)
# measure the time required to submit a remote task to the scheduler
elapsed_times = []
@@ -78,7 +77,7 @@ class MicroBenchmarkTest(unittest.TestCase):
print " worst: {}".format(elapsed_times[999])
self.assertTrue(average_elapsed_time < 0.002) # should take 0.00087
services.cleanup()
ray.services.cleanup()
if __name__ == "__main__":
unittest.main()
+45 -48
View File
@@ -1,8 +1,5 @@
import unittest
import ray
import ray.serialization as serialization
import ray.services as services
import ray.worker as worker
import numpy as np
import time
import subprocess32 as subprocess
@@ -34,35 +31,35 @@ class UserDefinedType(object):
class SerializationTest(unittest.TestCase):
def roundTripTest(self, data):
serialized, _ = serialization.serialize(ray.worker.global_worker.handle, data)
result = serialization.deserialize(ray.worker.global_worker.handle, serialized)
serialized, _ = ray.serialization.serialize(ray.worker.global_worker.handle, data)
result = ray.serialization.deserialize(ray.worker.global_worker.handle, serialized)
self.assertEqual(data, result)
def numpyTypeTest(self, typ):
a = np.random.randint(0, 10, size=(100, 100)).astype(typ)
b, _ = serialization.serialize(ray.worker.global_worker.handle, a)
c = serialization.deserialize(ray.worker.global_worker.handle, b)
b, _ = ray.serialization.serialize(ray.worker.global_worker.handle, a)
c = ray.serialization.deserialize(ray.worker.global_worker.handle, b)
self.assertTrue((a == c).all())
a = np.array(0).astype(typ)
b, _ = serialization.serialize(ray.worker.global_worker.handle, a)
c = serialization.deserialize(ray.worker.global_worker.handle, b)
b, _ = ray.serialization.serialize(ray.worker.global_worker.handle, a)
c = ray.serialization.deserialize(ray.worker.global_worker.handle, b)
self.assertTrue((a == c).all())
a = np.empty((0,)).astype(typ)
b, _ = serialization.serialize(ray.worker.global_worker.handle, a)
c = serialization.deserialize(ray.worker.global_worker.handle, b)
b, _ = ray.serialization.serialize(ray.worker.global_worker.handle, a)
c = ray.serialization.deserialize(ray.worker.global_worker.handle, b)
self.assertTrue(a.dtype == c.dtype)
def testSerialize(self):
services.start_ray_local()
ray.services.start_ray_local()
for val in RAY_TEST_OBJECTS:
self.roundTripTest(val)
a = np.zeros((100, 100))
res, _ = serialization.serialize(ray.worker.global_worker.handle, a)
b = serialization.deserialize(ray.worker.global_worker.handle, res)
res, _ = ray.serialization.serialize(ray.worker.global_worker.handle, a)
b = ray.serialization.deserialize(ray.worker.global_worker.handle, res)
self.assertTrue((a == b).all())
self.numpyTypeTest("int8")
@@ -80,8 +77,8 @@ class SerializationTest(unittest.TestCase):
ref3 = ray.put(0)
a = np.array([[ref0, ref1], [ref2, ref3]])
capsule, _ = serialization.serialize(ray.worker.global_worker.handle, a)
result = serialization.deserialize(ray.worker.global_worker.handle, capsule)
capsule, _ = ray.serialization.serialize(ray.worker.global_worker.handle, a)
result = ray.serialization.deserialize(ray.worker.global_worker.handle, capsule)
self.assertTrue((a == result).all())
self.roundTripTest(ref0)
@@ -89,13 +86,13 @@ class SerializationTest(unittest.TestCase):
self.roundTripTest({"0": ref0, "1": ref1, "2": ref2, "3": ref3})
self.roundTripTest((ref0, 1))
services.cleanup()
ray.services.cleanup()
class ObjStoreTest(unittest.TestCase):
# Test setting up object stores, transfering data between them and retrieving data to a client
def testObjStore(self):
[w1, w2] = services.start_services_local(return_drivers=True, num_objstores=2, num_workers_per_objstore=0)
[w1, w2] = ray.services.start_services_local(return_drivers=True, num_objstores=2, num_workers_per_objstore=0)
# putting and getting an object shouldn't change it
for data in ["h", "h" * 10000, 0, 0.0]:
@@ -105,14 +102,14 @@ class ObjStoreTest(unittest.TestCase):
# putting an object, shipping it to another worker, and getting it shouldn't change it
for data in ["h", "h" * 10000, 0, 0.0, [1, 2, 3, "a", (1, 2)], ("a", ("b", 3))]:
objref = worker.put(data, w1)
result = worker.get(objref, w2)
objref = ray.put(data, w1)
result = ray.get(objref, w2)
self.assertEqual(result, data)
# putting an array, shipping it to another worker, and getting it shouldn't change it
for data in [np.zeros([10, 20]), np.random.normal(size=[45, 25])]:
objref = worker.put(data, w1)
result = worker.get(objref, w2)
objref = ray.put(data, w1)
result = ray.get(objref, w2)
self.assertTrue(np.alltrue(result == data))
"""
@@ -127,24 +124,24 @@ class ObjStoreTest(unittest.TestCase):
# shipping a numpy array inside something else should be fine
data = ("a", np.random.normal(size=[10, 10]))
objref = worker.put(data, w1)
result = worker.get(objref, w2)
objref = ray.put(data, w1)
result = ray.get(objref, w2)
self.assertTrue(data[0] == result[0])
self.assertTrue(np.alltrue(data[1] == result[1]))
# shipping a numpy array inside something else should be fine
data = ["a", np.random.normal(size=[10, 10])]
objref = worker.put(data, w1)
result = worker.get(objref, w2)
objref = ray.put(data, w1)
result = ray.get(objref, w2)
self.assertTrue(data[0] == result[0])
self.assertTrue(np.alltrue(data[1] == result[1]))
services.cleanup()
ray.services.cleanup()
class WorkerTest(unittest.TestCase):
def testPutGet(self):
services.start_ray_local()
ray.services.start_ray_local()
for i in range(100):
value_before = i * 10 ** 6
@@ -170,13 +167,13 @@ class WorkerTest(unittest.TestCase):
value_after = ray.get(objref)
self.assertEqual(value_before, value_after)
services.cleanup()
ray.services.cleanup()
class APITest(unittest.TestCase):
def testObjRefAliasing(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=3, worker_path=worker_path)
ray.services.start_ray_local(num_workers=3, worker_path=worker_path)
ref = test_functions.test_alias_f()
self.assertTrue(np.alltrue(ray.get(ref) == np.ones([3, 4, 5])))
@@ -185,11 +182,11 @@ class APITest(unittest.TestCase):
ref = test_functions.test_alias_h()
self.assertTrue(np.alltrue(ray.get(ref) == np.ones([3, 4, 5])))
services.cleanup()
ray.services.cleanup()
def testKeywordArgs(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=1, worker_path=worker_path)
ray.services.start_ray_local(num_workers=1, worker_path=worker_path)
x = test_functions.keyword_fct1(1)
self.assertEqual(ray.get(x), "1 hello")
@@ -222,11 +219,11 @@ class APITest(unittest.TestCase):
x = test_functions.keyword_fct3(0, 1)
self.assertEqual(ray.get(x), "0 1 hello world")
services.cleanup()
ray.services.cleanup()
def testVariableNumberOfArgs(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=1, worker_path=worker_path)
ray.services.start_ray_local(num_workers=1, worker_path=worker_path)
x = test_functions.varargs_fct1(0, 1, 2)
self.assertEqual(ray.get(x), "0 1 2")
@@ -236,11 +233,11 @@ class APITest(unittest.TestCase):
self.assertTrue(test_functions.kwargs_exception_thrown)
self.assertTrue(test_functions.varargs_and_kwargs_exception_thrown)
services.cleanup()
ray.services.cleanup()
def testNoArgs(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.WORKER_MODE)
ray.services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.WORKER_MODE)
test_functions.no_op()
time.sleep(0.2)
@@ -257,11 +254,11 @@ class APITest(unittest.TestCase):
self.assertEqual(task_info["num_succeeded"], 1)
self.assertTrue("The @remote decorator for function test_functions.no_op_fail has 0 return values, but test_functions.no_op_fail returned more than 0 values." in task_info["failed_tasks"][0].get("error_message"))
services.cleanup()
ray.services.cleanup()
def testTypeChecking(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.WORKER_MODE)
ray.services.start_ray_local(num_workers=1, worker_path=worker_path, driver_mode=ray.WORKER_MODE)
# Make sure that these functions throw exceptions because there return
# values do not type check.
@@ -273,12 +270,12 @@ class APITest(unittest.TestCase):
self.assertEqual(len(task_info["running_tasks"]), 0)
self.assertEqual(task_info["num_succeeded"], 0)
services.cleanup()
ray.services.cleanup()
class TaskStatusTest(unittest.TestCase):
def testFailedTask(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=3, worker_path=worker_path, driver_mode=ray.WORKER_MODE)
ray.services.start_ray_local(num_workers=3, worker_path=worker_path, driver_mode=ray.WORKER_MODE)
test_functions.test_alias_f()
test_functions.throw_exception_fct1()
test_functions.throw_exception_fct1()
@@ -324,7 +321,7 @@ class ReferenceCountingTest(unittest.TestCase):
def testDeallocation(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=3, worker_path=worker_path)
ray.services.start_ray_local(num_workers=3, worker_path=worker_path)
x = test_functions.test_alias_f()
ray.get(x)
@@ -370,11 +367,11 @@ class ReferenceCountingTest(unittest.TestCase):
time.sleep(0.1)
self.assertTrue(ray.scheduler_info()["reference_counts"][objref_val:(objref_val + 3)] == [-1, -1, -1])
services.cleanup()
ray.services.cleanup()
def testGet(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=3, worker_path=worker_path)
ray.services.start_ray_local(num_workers=3, worker_path=worker_path)
for val in RAY_TEST_OBJECTS + [np.zeros((2, 2)), UserDefinedType()]:
objref_val = check_get_deallocated(val)
@@ -395,12 +392,12 @@ class ReferenceCountingTest(unittest.TestCase):
self.assertTrue(np.alltrue(result == data))
"""
services.cleanup()
ray.services.cleanup()
@unittest.expectedFailure
def testGetFailing(self):
worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "test_worker.py")
services.start_ray_local(num_workers=3, worker_path=worker_path)
ray.services.start_ray_local(num_workers=3, worker_path=worker_path)
# This is failing, because for bool and None, we cannot track python
# refcounts and therefore cannot keep the refcount up
@@ -412,12 +409,12 @@ class ReferenceCountingTest(unittest.TestCase):
x, objref_val = check_get_not_deallocated(val)
self.assertEqual(ray.scheduler_info()["reference_counts"][objref_val], 1)
services.cleanup()
ray.services.cleanup()
class PythonModeTest(unittest.TestCase):
def testObjRefAliasing(self):
services.start_ray_local(driver_mode=ray.PYTHON_MODE)
ray.services.start_ray_local(driver_mode=ray.PYTHON_MODE)
xref = test_functions.test_alias_h()
self.assertTrue(np.alltrue(xref == np.ones([3, 4, 5]))) # remote functions should return by value
@@ -433,7 +430,7 @@ class PythonModeTest(unittest.TestCase):
self.assertTrue(np.alltrue(aref == np.array([0, 0]))) # python_mode_g should not mutate aref
self.assertTrue(np.alltrue(bref == np.array([1, 0])))
services.cleanup()
ray.services.cleanup()
if __name__ == "__main__":
unittest.main()
+2 -4
View File
@@ -8,8 +8,6 @@ import ray.array.distributed as da
import ray.datasets.imagenet
import ray
import ray.services as services
import ray.worker as worker
parser = argparse.ArgumentParser(description='Parse addresses for the worker to connect to.')
parser.add_argument("--scheduler-address", default="127.0.0.1:10001", type=str, help="the scheduler's address")
@@ -18,7 +16,7 @@ parser.add_argument("--worker-address", default="127.0.0.1:40001", type=str, hel
if __name__ == "__main__":
args = parser.parse_args()
worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
ray.worker.connect(args.scheduler_address, args.objstore_address, args.worker_address)
ray.register_module(test_functions)
ray.register_module(ra)
@@ -29,4 +27,4 @@ if __name__ == "__main__":
ray.register_module(da.linalg)
ray.register_module(sys.modules[__name__])
worker.main_loop()
ray.worker.main_loop()