mirror of
https://github.com/vale981/ray
synced 2025-03-06 02:21:39 -05:00
[tune] Added resources_per_trial arg to validate_save_restore u… (#6032)
This commit is contained in:
parent
c23eae5998
commit
18241f4a2d
1 changed files with 6 additions and 2 deletions
|
@ -200,18 +200,22 @@ def _from_pinnable(obj):
|
|||
return obj[0]
|
||||
|
||||
|
||||
def validate_save_restore(trainable_cls, config=None, use_object_store=False):
|
||||
def validate_save_restore(trainable_cls,
|
||||
config=None,
|
||||
num_gpus=0,
|
||||
use_object_store=False):
|
||||
"""Helper method to check if your Trainable class will resume correctly.
|
||||
|
||||
Args:
|
||||
trainable_cls: Trainable class for evaluation.
|
||||
config (dict): Config to pass to Trainable when testing.
|
||||
num_gpus (int): GPU resources to allocate when testing.
|
||||
use_object_store (bool): Whether to save and restore to Ray's object
|
||||
store. Recommended to set this to True if planning to use
|
||||
algorithms that pause training (i.e., PBT, HyperBand).
|
||||
"""
|
||||
assert ray.is_initialized(), "Need Ray to be initialized."
|
||||
remote_cls = ray.remote(trainable_cls)
|
||||
remote_cls = ray.remote(num_gpus=num_gpus)(trainable_cls)
|
||||
trainable_1 = remote_cls.remote(config=config)
|
||||
trainable_2 = remote_cls.remote(config=config)
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue