mirror of
https://github.com/wassname/ray.git
synced 2026-06-29 13:32:36 +08:00
[tune] Added resources_per_trial arg to validate_save_restore u… (#6032)
This commit is contained in:
@@ -200,18 +200,22 @@ def _from_pinnable(obj):
|
||||
return obj[0]
|
||||
|
||||
|
||||
def validate_save_restore(trainable_cls, config=None, use_object_store=False):
|
||||
def validate_save_restore(trainable_cls,
|
||||
config=None,
|
||||
num_gpus=0,
|
||||
use_object_store=False):
|
||||
"""Helper method to check if your Trainable class will resume correctly.
|
||||
|
||||
Args:
|
||||
trainable_cls: Trainable class for evaluation.
|
||||
config (dict): Config to pass to Trainable when testing.
|
||||
num_gpus (int): GPU resources to allocate when testing.
|
||||
use_object_store (bool): Whether to save and restore to Ray's object
|
||||
store. Recommended to set this to True if planning to use
|
||||
algorithms that pause training (i.e., PBT, HyperBand).
|
||||
"""
|
||||
assert ray.is_initialized(), "Need Ray to be initialized."
|
||||
remote_cls = ray.remote(trainable_cls)
|
||||
remote_cls = ray.remote(num_gpus=num_gpus)(trainable_cls)
|
||||
trainable_1 = remote_cls.remote(config=config)
|
||||
trainable_2 = remote_cls.remote(config=config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user