diff --git a/python/ray/tune/util.py b/python/ray/tune/util.py index 256411ace..e0ade8b12 100644 --- a/python/ray/tune/util.py +++ b/python/ray/tune/util.py @@ -200,18 +200,22 @@ def _from_pinnable(obj): return obj[0] -def validate_save_restore(trainable_cls, config=None, use_object_store=False): +def validate_save_restore(trainable_cls, + config=None, + num_gpus=0, + use_object_store=False): """Helper method to check if your Trainable class will resume correctly. Args: trainable_cls: Trainable class for evaluation. config (dict): Config to pass to Trainable when testing. + num_gpus (int): GPU resources to allocate when testing. use_object_store (bool): Whether to save and restore to Ray's object store. Recommended to set this to True if planning to use algorithms that pause training (i.e., PBT, HyperBand). """ assert ray.is_initialized(), "Need Ray to be initialized." - remote_cls = ray.remote(trainable_cls) + remote_cls = ray.remote(num_gpus=num_gpus)(trainable_cls) trainable_1 = remote_cls.remote(config=config) trainable_2 = remote_cls.remote(config=config)