From 18241f4a2d2445943435fc5251104ea242a8d1f9 Mon Sep 17 00:00:00 2001 From: visatish Date: Mon, 4 Nov 2019 13:24:46 -0800 Subject: [PATCH] =?UTF-8?q?[tune]=20Added=20resources=5Fper=5Ftrial=20arg?= =?UTF-8?q?=20to=20validate=5Fsave=5Frestore=20u=E2=80=A6=20(#6032)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- python/ray/tune/util.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py index 256411ace..e0ade8b12 100644 --- a/python/ray/tune/util.py +++ b/python/ray/tune/util.py @@ -200,18 +200,22 @@ def _from_pinnable(obj): return obj[0] -def validate_save_restore(trainable_cls, config=None, use_object_store=False): +def validate_save_restore(trainable_cls, + config=None, + num_gpus=0, + use_object_store=False): """Helper method to check if your Trainable class will resume correctly. Args: trainable_cls: Trainable class for evaluation. config (dict): Config to pass to Trainable when testing. + num_gpus (int): GPU resources to allocate when testing. use_object_store (bool): Whether to save and restore to Ray's object store. Recommended to set this to True if planning to use algorithms that pause training (i.e., PBT, HyperBand). """ assert ray.is_initialized(), "Need Ray to be initialized." - remote_cls = ray.remote(trainable_cls) + remote_cls = ray.remote(num_gpus=num_gpus)(trainable_cls) trainable_1 = remote_cls.remote(config=config) trainable_2 = remote_cls.remote(config=config)