[tune] Added resources_per_trial arg to validate_save_restore u… (#6032)

This commit is contained in:
visatish 2019-11-04 13:24:46 -08:00 committed by Richard Liaw
parent c23eae5998
commit 18241f4a2d

View file

@ -200,18 +200,22 @@ def _from_pinnable(obj):
return obj[0]
def validate_save_restore(trainable_cls, config=None, use_object_store=False):
def validate_save_restore(trainable_cls,
config=None,
num_gpus=0,
use_object_store=False):
"""Helper method to check if your Trainable class will resume correctly.
Args:
trainable_cls: Trainable class for evaluation.
config (dict): Config to pass to Trainable when testing.
num_gpus (int): GPU resources to allocate when testing.
use_object_store (bool): Whether to save and restore to Ray's object
store. Recommended to set this to True if planning to use
algorithms that pause training (i.e., PBT, HyperBand).
"""
assert ray.is_initialized(), "Need Ray to be initialized."
remote_cls = ray.remote(trainable_cls)
remote_cls = ray.remote(num_gpus=num_gpus)(trainable_cls)
trainable_1 = remote_cls.remote(config=config)
trainable_2 = remote_cls.remote(config=config)