[tune] Allow fetching pinned objects from trainable functions (#1895)

* updates

* lint

* Update util.py

* Update function_runner.py

* updates
This commit is contained in:
Eric Liang
2018-04-16 15:54:38 -07:00
committed by GitHub
parent ddfc875149
commit ed8c0f1a38
3 changed files with 59 additions and 0 deletions
+2
View File
@@ -9,6 +9,7 @@ import traceback
from ray.tune import TuneError
from ray.tune.trainable import Trainable
from ray.tune.result import TrainingResult
from ray.tune.util import _serve_get_pin_requests
class StatusReporter(object):
@@ -108,6 +109,7 @@ class FunctionRunner(Trainable):
self._default_config["script_min_iter_time_s"]))
result = self._status_reporter._get_and_clear_status()
while result is None:
_serve_get_pin_requests()
time.sleep(1)
result = self._status_reporter._get_and_clear_status()
if result.timesteps_total is None:
+19
View File
@@ -38,6 +38,25 @@ class TrainableFunctionApiTest(unittest.TestCase):
self.assertEqual(ray.get(f.remote()), "hello")
def testFetchPinned(self):
X = pin_in_object_store("hello")
def train(config, reporter):
get_pinned_object(X)
reporter(timesteps_total=100, done=True)
register_trainable("f1", train)
[trial] = run_experiments({
"foo": {
"run": "f1",
"config": {
"script_min_iter_time_s": 0,
},
}
})
self.assertEqual(trial.status, Trial.TERMINATED)
self.assertEqual(trial.last_result.timesteps_total, 100)
def testRegisterEnv(self):
register_env("foo", lambda: None)
self.assertRaises(TypeError, lambda: register_env("foo", 2))
+38
View File
@@ -3,15 +3,25 @@ from __future__ import division
from __future__ import print_function
import base64
import queue
import threading
import ray
from ray.tune.registry import _to_pinnable, _from_pinnable
_pinned_objects = []
_fetch_requests = queue.Queue()
PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"
def pin_in_object_store(obj):
"""Pin an object in the object store.
It will be available as long as the pinning process is alive. The pinned
object can be retrieved by calling get_pinned_object on the identifier
returned by this call.
"""
obj_id = ray.put(_to_pinnable(obj))
_pinned_objects.append(ray.get(obj_id))
return "{}{}".format(PINNED_OBJECT_PREFIX,
@@ -19,12 +29,40 @@ def pin_in_object_store(obj):
def get_pinned_object(pinned_id):
"""Retrieve a pinned object from the object store."""
from ray.local_scheduler import ObjectID
if threading.current_thread().getName() != "MainThread":
placeholder = queue.Queue()
_fetch_requests.put((placeholder, pinned_id))
print("Requesting main thread to fetch pinned object", pinned_id)
return placeholder.get()
return _from_pinnable(
ray.get(
ObjectID(base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
def _serve_get_pin_requests():
"""This is hack to avoid ray.get() on the function runner thread.
The issue is that we run trainable functions on a separate thread,
which cannot access Ray API methods. So instead, that thread puts the
fetch in a queue that is periodically checked from the main thread.
"""
assert threading.current_thread().getName() == "MainThread"
try:
while not _fetch_requests.empty():
(placeholder, pinned_id) = _fetch_requests.get_nowait()
print("Fetching pinned object from main thread", pinned_id)
placeholder.put(get_pinned_object(pinned_id))
except queue.Empty:
pass
if __name__ == '__main__':
ray.init()
X = pin_in_object_store("hello")