[tune] Improve BOHB/ConfigSpace dependency check (#15064)

This commit is contained in:
Kai Fricke 2021-04-02 10:19:49 +02:00 committed by GitHub
parent 3965310f93
commit 8de66fce3d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -5,7 +5,6 @@ import logging
import math
from typing import Dict, List, Optional, Union
import ConfigSpace
from ray.tune.result import DEFAULT_METRIC
from ray.tune.sample import Categorical, Domain, Float, Integer, LogUniform, \
Normal, \
@ -17,6 +16,12 @@ from ray.tune.suggest.suggestion import UNRESOLVED_SEARCH_SPACE, \
from ray.tune.suggest.variant_generator import parse_spec_vars
from ray.tune.utils.util import flatten_dict, unflatten_list_dict
try:
import ConfigSpace
from hpbandster.optimizers.config_generators.bohb import BOHB
except ImportError:
BOHB = ConfigSpace = None
logger = logging.getLogger(__name__)
@ -105,15 +110,14 @@ class TuneBOHB(Searcher):
"""
def __init__(self,
space: Optional[Union[Dict,
ConfigSpace.ConfigurationSpace]] = None,
space: Optional[Union[
Dict, "ConfigSpace.ConfigurationSpace"]] = None,
bohb_config: Optional[Dict] = None,
max_concurrent: int = 10,
metric: Optional[str] = None,
mode: Optional[str] = None,
points_to_evaluate: Optional[List[Dict]] = None,
seed: Optional[int] = None):
from hpbandster.optimizers.config_generators.bohb import BOHB
assert BOHB is not None, """HpBandSter must be installed!
You can install HpBandSter with the command:
`pip install hpbandster ConfigSpace`."""
@ -236,7 +240,7 @@ class TuneBOHB(Searcher):
self.running.add(trial_id)
@staticmethod
def convert_search_space(spec: Dict) -> ConfigSpace.ConfigurationSpace:
def convert_search_space(spec: Dict) -> "ConfigSpace.ConfigurationSpace":
resolved_vars, domain_vars, grid_vars = parse_spec_vars(spec)
if grid_vars: