diff --git a/python/ray/tune/function_runner.py b/python/ray/tune/function_runner.py
index eefc11240..b4fff276b 100644
--- a/python/ray/tune/function_runner.py
+++ b/python/ray/tune/function_runner.py
@@ -9,6 +9,7 @@ import traceback
 from ray.tune import TuneError
 from ray.tune.trainable import Trainable
 from ray.tune.result import TrainingResult
+from ray.tune.util import _serve_get_pin_requests
 
 
 class StatusReporter(object):
@@ -108,6 +109,7 @@ class FunctionRunner(Trainable):
                 self._default_config["script_min_iter_time_s"]))
         result = self._status_reporter._get_and_clear_status()
         while result is None:
+            _serve_get_pin_requests()
             time.sleep(1)
             result = self._status_reporter._get_and_clear_status()
         if result.timesteps_total is None:
diff --git a/python/ray/tune/test/trial_runner_test.py b/python/ray/tune/test/trial_runner_test.py
index 0f20cc9e3..8aae21c92 100644
--- a/python/ray/tune/test/trial_runner_test.py
+++ b/python/ray/tune/test/trial_runner_test.py
@@ -38,6 +38,25 @@ class TrainableFunctionApiTest(unittest.TestCase):
 
         self.assertEqual(ray.get(f.remote()), "hello")
 
+    def testFetchPinned(self):
+        X = pin_in_object_store("hello")
+
+        def train(config, reporter):
+            get_pinned_object(X)
+            reporter(timesteps_total=100, done=True)
+
+        register_trainable("f1", train)
+        [trial] = run_experiments({
+            "foo": {
+                "run": "f1",
+                "config": {
+                    "script_min_iter_time_s": 0,
+                },
+            }
+        })
+        self.assertEqual(trial.status, Trial.TERMINATED)
+        self.assertEqual(trial.last_result.timesteps_total, 100)
+
     def testRegisterEnv(self):
         register_env("foo", lambda: None)
         self.assertRaises(TypeError, lambda: register_env("foo", 2))
diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py
index 041383d54..45718916a 100644
--- a/python/ray/tune/util.py
+++ b/python/ray/tune/util.py
@@ -3,15 +3,29 @@ from __future__ import division
 from __future__ import print_function
 
 import base64
+import threading
+
+try:
+    import queue
+except ImportError:  # Python 2 names this module "Queue".
+    import Queue as queue
 
 import ray
 from ray.tune.registry import _to_pinnable, _from_pinnable
 
 _pinned_objects = []
+_fetch_requests = queue.Queue()
 PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"
 
 
 def pin_in_object_store(obj):
+    """Pin an object in the object store.
+
+    It will be available as long as the pinning process is alive. The pinned
+    object can be retrieved by calling get_pinned_object on the identifier
+    returned by this call.
+    """
+
     obj_id = ray.put(_to_pinnable(obj))
     _pinned_objects.append(ray.get(obj_id))
     return "{}{}".format(PINNED_OBJECT_PREFIX,
@@ -19,12 +33,46 @@ def pin_in_object_store(obj):
 
 
 def get_pinned_object(pinned_id):
+    """Retrieve a pinned object from the object store.
+
+    When called off the main thread, the fetch is handed to the main thread
+    through a queue, and this call blocks until the main thread serves it
+    (see _serve_get_pin_requests).
+    """
+
     from ray.local_scheduler import ObjectID
+
+    if threading.current_thread().name != "MainThread":
+        placeholder = queue.Queue()
+        _fetch_requests.put((placeholder, pinned_id))
+        print("Requesting main thread to fetch pinned object", pinned_id)
+        return placeholder.get()
+
     return _from_pinnable(
         ray.get(
             ObjectID(base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
 
 
+def _serve_get_pin_requests():
+    """This is a hack to avoid ray.get() on the function runner thread.
+
+    The issue is that we run trainable functions on a separate thread,
+    which cannot access Ray API methods. So instead, that thread puts the
+    fetch in a queue that is periodically checked from the main thread.
+    """
+
+    assert threading.current_thread().name == "MainThread"
+
+    try:
+        while not _fetch_requests.empty():
+            (placeholder, pinned_id) = _fetch_requests.get_nowait()
+            print("Fetching pinned object from main thread", pinned_id)
+            # On the main thread this takes the direct ray.get() path.
+            placeholder.put(get_pinned_object(pinned_id))
+    except queue.Empty:
+        pass
+
+
 if __name__ == '__main__':
     ray.init()
     X = pin_in_object_store("hello")