From d97bce0d64e1c1764e6fc0ca06c64682d57ccf9e Mon Sep 17 00:00:00 2001 From: Robert Nishihara Date: Sun, 26 Jun 2016 13:43:54 -0700 Subject: [PATCH] allow driver to run in PYTHON_MODE, which is equivalent to serial Python --- lib/python/ray/__init__.py | 1 + lib/python/ray/services.py | 3 ++- lib/python/ray/worker.py | 10 ++++++++++ test/runtest.py | 21 +++++++++++++++++++++ test/test_functions.py | 10 ++++++++++ 5 files changed, 44 insertions(+), 1 deletion(-) diff --git a/lib/python/ray/__init__.py b/lib/python/ray/__init__.py index 5a43889bd..3fdecd489 100644 --- a/lib/python/ray/__init__.py +++ b/lib/python/ray/__init__.py @@ -4,6 +4,7 @@ SCRIPT_MODE = 0 WORKER_MODE = 1 SHELL_MODE = 2 +PYTHON_MODE = 2 import libraylib as lib import serialization diff --git a/lib/python/ray/services.py b/lib/python/ray/services.py index 264ed63c2..9b621fc91 100644 --- a/lib/python/ray/services.py +++ b/lib/python/ray/services.py @@ -105,7 +105,8 @@ def start_node(scheduler_address, node_ip_address, num_workers, worker_path=None time.sleep(0.5) # driver_mode should equal ray.SCRIPT_MODE if this is being run in a script and -# ray.SHELL_MODE if it is being used interactively in a shell. +# ray.SHELL_MODE if it is being used interactively in a shell. It can also equal +# ray.PYTHON_MODE to run things in a manner equivalent to serial Python code. def start_singlenode_cluster(return_drivers=False, num_objstores=1, num_workers_per_objstore=0, worker_path=None, driver_mode=ray.SCRIPT_MODE): global drivers if num_workers_per_objstore > 0 and worker_path is None: diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index db3522d58..4099d712a 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -7,6 +7,7 @@ import typing import funcsigs import numpy as np import colorama +import copy import ray from ray.config import LOG_DIRECTORY, LOG_TIMESTAMP @@ -170,6 +171,8 @@ def disconnect(worker=global_worker): ray.lib.disconnect(worker.handle) def get(objref, worker=global_worker): + if worker.mode == ray.PYTHON_MODE: + return objref # In ray.PYTHON_MODE, ray.get is the identity operation (the input will actually be a value not an objref) ray.lib.request_object(worker.handle, objref) if worker.mode == ray.SHELL_MODE or worker.mode == ray.SCRIPT_MODE: print_task_info(ray.lib.task_info(worker.handle), worker.mode) @@ -179,6 +182,8 @@ def get(objref, worker=global_worker): return value def put(value, worker=global_worker): + if worker.mode == ray.PYTHON_MODE: + return value # In ray.PYTHON_MODE, ray.put is the identity operation objref = ray.lib.get_objref(worker.handle) worker.put_object(objref, value) if worker.mode == ray.SHELL_MODE or worker.mode == ray.SCRIPT_MODE: @@ -225,6 +230,11 @@ def remote(arg_types, return_types, worker=global_worker): """This is what gets run immediately when a worker calls a remote function.""" args = list(args) args.extend([kwargs[keyword] if kwargs.has_key(keyword) else default for keyword, default in func_call.keyword_defaults[len(args):]]) # fill in the remaining arguments + if worker.mode == ray.PYTHON_MODE: + # In ray.PYTHON_MODE, remote calls simply execute the function. We copy + # the arguments to prevent the function call from mutating them and to + # match the usual behavior of immutable remote objects. + return func(*copy.deepcopy(args)) check_arguments(func_call, args) # throws an exception if args are invalid objrefs = worker.submit_task(func_call.func_name, args) if len(objrefs) == 1: diff --git a/test/runtest.py b/test/runtest.py index a0994840d..e35d17bfc 100644 --- a/test/runtest.py +++ b/test/runtest.py @@ -377,5 +377,26 @@ class ReferenceCountingTest(unittest.TestCase): services.cleanup() +class PythonModeTest(unittest.TestCase): + + def testObjRefAliasing(self): + services.start_singlenode_cluster(driver_mode=ray.PYTHON_MODE) + + xref = test_functions.test_alias_h() + self.assertTrue(np.alltrue(xref == np.ones([3, 4, 5]))) # remote functions should return by value + self.assertTrue(np.alltrue(xref == ray.get(xref))) # ray.get should be the identity + y = np.random.normal(size=[11, 12]) + self.assertTrue(np.alltrue(y == ray.put(y))) # ray.put should be the identity + + # make sure objects are immutable, this example is why we need to copy + # arguments before passing them into remote functions in python mode + aref = test_functions.python_mode_f() + self.assertTrue(np.alltrue(aref == np.array([0, 0]))) + bref = test_functions.python_mode_g(aref) + self.assertTrue(np.alltrue(aref == np.array([0, 0]))) # python_mode_g should not mutate aref + self.assertTrue(np.alltrue(bref == np.array([1, 0]))) + + services.cleanup() + if __name__ == "__main__": unittest.main() diff --git a/test/test_functions.py b/test/test_functions.py index 7e75671af..dfe8a6b2a 100644 --- a/test/test_functions.py +++ b/test/test_functions.py @@ -85,3 +85,13 @@ def throw_exception_fct2(): @ray.remote([float], [int, str, np.ndarray]) def throw_exception_fct3(x): raise Exception("Test function 3 intentionally failed.") + +# test Python mode +@ray.remote([], [np.ndarray]) +def python_mode_f(): + return np.array([0, 0]) + +@ray.remote([np.ndarray], [np.ndarray]) +def python_mode_g(x): + x[0] = 1 + return x