mirror of
https://github.com/wassname/ray.git
synced 2026-07-02 10:46:13 +08:00
[tune] Add util function to broadcast objects (#1845)
* add util * Fri Apr 6 15:09:20 PDT 2018 * doc * Fri Apr 6 15:21:42 PDT 2018 * Fri Apr 6 15:28:07 PDT 2018 * Fri Apr 6 15:28:26 PDT 2018 * Update tune-config.rst * Update tune-config.rst
This commit is contained in:
@@ -13,6 +13,7 @@ from ray.tune import Trainable, TuneError
|
||||
from ray.tune import register_env, register_trainable, run_experiments
|
||||
from ray.tune.registry import _default_registry, TRAINABLE_CLASS
|
||||
from ray.tune.result import DEFAULT_RESULTS_DIR
|
||||
from ray.tune.util import pin_in_object_store, get_pinned_object
|
||||
from ray.tune.experiment import Experiment
|
||||
from ray.tune.trial import Trial, Resources
|
||||
from ray.tune.trial_runner import TrialRunner
|
||||
@@ -28,6 +29,15 @@ class TrainableFunctionApiTest(unittest.TestCase):
|
||||
ray.worker.cleanup()
|
||||
_register_all() # re-register the evicted objects
|
||||
|
||||
def testPinObject(self):
|
||||
X = pin_in_object_store("hello")
|
||||
|
||||
@ray.remote
|
||||
def f():
|
||||
return get_pinned_object(X)
|
||||
|
||||
self.assertEqual(ray.get(f.remote()), "hello")
|
||||
|
||||
def testRegisterEnv(self):
|
||||
register_env("foo", lambda: None)
|
||||
self.assertRaises(TypeError, lambda: register_env("foo", 2))
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import base64
|
||||
|
||||
import ray
|
||||
from ray.tune.registry import _to_pinnable, _from_pinnable
|
||||
|
||||
|
||||
_pinned_objects = []
|
||||
PINNED_OBJECT_PREFIX = "ray.tune.PinnedObject:"
|
||||
|
||||
|
||||
def pin_in_object_store(obj):
|
||||
obj_id = ray.put(_to_pinnable(obj))
|
||||
_pinned_objects.append(ray.get(obj_id))
|
||||
return "{}{}".format(
|
||||
PINNED_OBJECT_PREFIX, base64.b64encode(obj_id.id()).decode("utf-8"))
|
||||
|
||||
|
||||
def get_pinned_object(pinned_id):
|
||||
from ray.local_scheduler import ObjectID
|
||||
return _from_pinnable(ray.get(ObjectID(
|
||||
base64.b64decode(pinned_id[len(PINNED_OBJECT_PREFIX):]))))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
ray.init()
|
||||
X = pin_in_object_store("hello")
|
||||
print(X)
|
||||
result = get_pinned_object(X)
|
||||
print(result)
|
||||
Reference in New Issue
Block a user