mirror of
https://github.com/vale981/ray
synced 2025-03-07 02:51:39 -05:00
[tune] Improve BOHB/ConfigSpace dependency check (#15064)
This commit is contained in:
parent
3965310f93
commit
8de66fce3d
1 changed files with 9 additions and 5 deletions
|
@ -5,7 +5,6 @@ import logging
|
|||
import math
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import ConfigSpace
|
||||
from ray.tune.result import DEFAULT_METRIC
|
||||
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
|
||||
Normal, \
|
||||
|
@ -17,6 +16,12 @@ from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE, \
|
|||
from ray.tune.suggest.variant_generator import parse_spec_vars
|
||||
from ray.tune.utils.util import flatten_dict, unflatten_list_dict
|
||||
|
||||
try:
|
||||
import ConfigSpace
|
||||
from hpbandster.optimizers.config_generators.bohb import BOHB
|
||||
except ImportError:
|
||||
BOHB = ConfigSpace = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -105,15 +110,14 @@ class TuneBOHB(Searcher):
|
|||
"""
|
||||
|
||||
def __init__(self,
|
||||
space: Optional[Union[Dict,
|
||||
ConfigSpace.ConfigurationSpace]] = None,
|
||||
space: Optional[Union[
|
||||
Dict, "ConfigSpace.ConfigurationSpace"]] = None,
|
||||
bohb_config: Optional[Dict] = None,
|
||||
max_concurrent: int = 10,
|
||||
metric: Optional[str] = None,
|
||||
mode: Optional[str] = None,
|
||||
points_to_evaluate: Optional[List[Dict]] = None,
|
||||
seed: Optional[int] = None):
|
||||
from hpbandster.optimizers.config_generators.bohb import BOHB
|
||||
assert BOHB is not None, """HpBandSter must be installed!
|
||||
You can install HpBandSter with the command:
|
||||
`pip install hpbandster ConfigSpace`."""
|
||||
|
@ -236,7 +240,7 @@ class TuneBOHB(Searcher):
|
|||
self.running.add(trial_id)
|
||||
|
||||
@staticmethod
|
||||
def convert_search_space(spec: Dict) -> ConfigSpace.ConfigurationSpace:
|
||||
def convert_search_space(spec: Dict) -> "ConfigSpace.ConfigurationSpace":
|
||||
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
|
||||
|
||||
if grid_vars:
|
||||
|
|
Loading…
Add table
Reference in a new issue