mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 07:23:55 +08:00
[tune] Allow fetching pinned objects from trainable functions (#1895)
* updates * lint * Update util.py * Update function_runner.py * updates
This commit is contained in:
@@ -9,6 +9,7 @@ import traceback
|
||||
from ray.tune import TuneError
|
||||
from ray.tune.trainable import Trainable
|
||||
from ray.tune.result import TrainingResult
|
||||
from ray.tune.util import _serve_get_pin_requests
|
||||
|
||||
|
||||
class StatusReporter(object):
|
||||
@@ -108,6 +109,7 @@ class FunctionRunner(Trainable):
|
||||
self._default_config["script_min_iter_time_s"]))
|
||||
result = self._status_reporter._get_and_clear_status()
|
||||
while result is None:
|
||||
_serve_get_pin_requests()
|
||||
time.sleep(1)
|
||||
result = self._status_reporter._get_and_clear_status()
|
||||
if result.timesteps_total is None:
|
||||
|
||||
@@ -38,6 +38,25 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
|
||||
self.assertEqual(ray.get(f.remote()), "hello")
|
||||
|
||||
def testFetchPinned(self):
    """Run a trainable function that calls get_pinned_object.

    The object is pinned by the test itself before the experiment starts;
    the trial is expected to terminate and report 100 total timesteps.
    """
    pinned_id = pin_in_object_store("hello")

    def train(config, reporter):
        # Fetch the pinned object from inside the trainable function, then
        # report a single, final result.
        get_pinned_object(pinned_id)
        reporter(timesteps_total=100, done=True)

    register_trainable("f1", train)

    experiment_spec = {
        "run": "f1",
        "config": {
            "script_min_iter_time_s": 0,
        },
    }
    trials = run_experiments({"foo": experiment_spec})

    # Exactly one trial is expected; unpacking fails otherwise.
    [trial] = trials
    self.assertEqual(trial.status, Trial.TERMINATED)
    self.assertEqual(trial.last_result.timesteps_total, 100)
|
||||
|
||||
def testRegisterEnv(self):
|
||||
register_env("foo", lambda: None)
|
||||
self.assertRaises(TypeError, lambda: register_env("foo", 2))
|
||||
|
||||
@@ -3,15 +3,25 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import base64
|
||||
import queue
|
||||
import threading
|
||||
|
||||
import ray
|
||||
from ray.tune.registry import _to_pinnable, _from_pinnable
|
||||
|
||||
# Values appended by pin_in_object_store after it fetches each pinned object.
# NOTE(review): presumably held here so the objects stay referenced for the
# lifetime of the pinning process -- confirm.
_pinned_objects = []
|
||||
# Queue of (placeholder, pinned_id) requests that get_pinned_object enqueues
# when called off the main thread; drained by _serve_get_pin_requests.
_fetch_requests = queue.Queue()
|
||||
# Prefix of the identifiers built by pin_in_object_store; get_pinned_object
# strips it before base64-decoding the remainder.
PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"
|
||||
|
||||
|
||||
def pin_in_object_store(obj):
|
||||
"""Pin an object in the object store.
|
||||
|
||||
It will be available as long as the pinning process is alive. The pinned
|
||||
object can be retrieved by calling get_pinned_object on the identifier
|
||||
returned by this call.
|
||||
"""
|
||||
|
||||
obj_id = ray.put(_to_pinnable(obj))
|
||||
_pinned_objects.append(ray.get(obj_id))
|
||||
return "{}{}".format(PINNED_OBJECT_PREFIX,
|
||||
@@ -19,12 +29,40 @@ def pin_in_object_store(obj):
|
||||
|
||||
|
||||
def get_pinned_object(pinned_id):
    """Retrieve a pinned object from the object store.

    When called from any thread other than the main thread, the fetch is not
    performed locally: the request is put on the module-level _fetch_requests
    queue and this call blocks until the main thread answers it (see
    _serve_get_pin_requests).

    Args:
        pinned_id (str): Identifier returned by pin_in_object_store, i.e.
            PINNED_OBJECT_PREFIX followed by a base64-encoded object ID.

    Returns:
        The object recovered by _from_pinnable from the stored value.
    """

    from ray.local_scheduler import ObjectID

    # Thread.name is used instead of getName(), which is deprecated since
    # Python 3.10; the attribute exists on every version this file targets.
    if threading.current_thread().name != "MainThread":
        # Private one-shot reply channel: the main thread puts the fetched
        # object here once it has served the request.
        placeholder = queue.Queue()
        _fetch_requests.put((placeholder, pinned_id))
        print("Requesting main thread to fetch pinned object", pinned_id)
        return placeholder.get()

    # Strip the textual prefix, then decode the remaining base64 payload
    # back into the raw object ID.
    encoded_id = pinned_id[len(PINNED_OBJECT_PREFIX):]
    object_id = ObjectID(base64.b64decode(encoded_id))
    return _from_pinnable(ray.get(object_id))
|
||||
|
||||
|
||||
def _serve_get_pin_requests():
    """Serve pending pinned-object fetch requests from the main thread.

    This is a hack to avoid ray.get() on the function runner thread.

    The issue is that we run trainable functions on a separate thread,
    which cannot access Ray API methods. So instead, that thread puts the
    fetch in a queue that is periodically checked from the main thread.

    Every request currently in _fetch_requests is answered by fetching the
    object and putting it on the request's placeholder queue. Returns None
    once the queue is drained.
    """

    # Thread.name is used instead of getName(), which is deprecated since
    # Python 3.10.
    assert threading.current_thread().name == "MainThread"

    while True:
        # Only the dequeue is guarded: an empty queue means all pending
        # requests have been served.
        try:
            placeholder, pinned_id = _fetch_requests.get_nowait()
        except queue.Empty:
            return
        print("Fetching pinned object from main thread", pinned_id)
        placeholder.put(get_pinned_object(pinned_id))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ray.init()
|
||||
X = pin_in_object_store("hello")
|
||||
|
||||
Reference in New Issue
Block a user