mirror of
https://github.com/vale981/ray
synced 2025-03-06 10:31:39 -05:00
[tune] allow reading trial results from json files in Analysis (#15915)
Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
This commit is contained in:
parent
cb878b6514
commit
e5b50fcc9d
2 changed files with 56 additions and 5 deletions
|
@ -17,13 +17,15 @@ except ImportError:
|
|||
|
||||
from ray.tune.error import TuneError
|
||||
from ray.tune.result import DEFAULT_METRIC, EXPR_PROGRESS_FILE, \
|
||||
EXPR_PARAM_FILE, CONFIG_PREFIX, TRAINING_ITERATION
|
||||
EXPR_RESULT_FILE, EXPR_PARAM_FILE, CONFIG_PREFIX, TRAINING_ITERATION
|
||||
from ray.tune.trial import Trial
|
||||
from ray.tune.utils.trainable import TrainableUtil
|
||||
from ray.tune.utils.util import unflattened_lookup
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_FILE_TYPE = "csv"
|
||||
|
||||
|
||||
class Analysis:
|
||||
"""Analyze all results from a directory of experiments.
|
||||
|
@ -39,12 +41,15 @@ class Analysis:
|
|||
default_mode (str): Default mode for comparing results. Has to be one
|
||||
of [min, max]. Can be overwritten with the ``mode`` parameter
|
||||
in the respective functions.
|
||||
file_type (str): Read results from json or csv files. Has to be one
|
||||
of [None, json, csv]. Defaults to csv.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
experiment_dir: str,
|
||||
default_metric: Optional[str] = None,
|
||||
default_mode: Optional[str] = None):
|
||||
default_mode: Optional[str] = None,
|
||||
file_type: Optional[str] = None):
|
||||
experiment_dir = os.path.expanduser(experiment_dir)
|
||||
if not os.path.isdir(experiment_dir):
|
||||
raise ValueError(
|
||||
|
@ -58,6 +63,7 @@ class Analysis:
|
|||
raise ValueError(
|
||||
"`default_mode` has to be None or one of [min, max]")
|
||||
self.default_mode = default_mode
|
||||
self._file_type = self._validate_filetype(file_type)
|
||||
|
||||
if self.default_metric is None and self.default_mode:
|
||||
# If only a mode was passed, use anonymous metric
|
||||
|
@ -70,6 +76,23 @@ class Analysis:
|
|||
else:
|
||||
self.fetch_trial_dataframes()
|
||||
|
||||
def _validate_filetype(self, file_type: Optional[str] = None):
    """Check that *file_type* names a supported results format.

    Args:
        file_type (str): One of [None, json, csv].

    Returns:
        The validated file type, substituting ``DEFAULT_FILE_TYPE``
        ("csv") when ``None`` was passed.

    Raises:
        ValueError: If ``file_type`` is not None, "json", or "csv".
    """
    allowed = (None, "json", "csv")
    if file_type not in allowed:
        raise ValueError(
            "`file_type` has to be None or one of [json, csv].")
    # Fall back to the default format when no explicit choice was made.
    return file_type or DEFAULT_FILE_TYPE
|
||||
|
||||
def set_filetype(self, file_type: Optional[str] = None):
    """Overrides the existing file type.

    Args:
        file_type (str): Read results from json or csv files. Has to be one
            of [None, json, csv]. Defaults to csv.
    """
    # Validate before assigning, so an invalid argument raises and
    # leaves the previous file type untouched.
    validated = self._validate_filetype(file_type)
    self._file_type = validated
    # Re-read all trial dataframes with the new format.
    self.fetch_trial_dataframes()
    return True
|
||||
|
||||
def _validate_metric(self, metric: str) -> str:
|
||||
if not metric and not self.default_metric:
|
||||
raise ValueError(
|
||||
|
@ -169,11 +192,21 @@ class Analysis:
|
|||
return None
|
||||
|
||||
def fetch_trial_dataframes(self) -> Dict[str, DataFrame]:
|
||||
"""Fetches trial dataframes from files.
|
||||
|
||||
Returns:
|
||||
A dictionary containing "trial dir" to Dataframe.
|
||||
"""
|
||||
fail_count = 0
|
||||
for path in self._get_trial_paths():
|
||||
try:
|
||||
self.trial_dataframes[path] = pd.read_csv(
|
||||
os.path.join(path, EXPR_PROGRESS_FILE))
|
||||
if self._file_type == "json":
|
||||
with open(os.path.join(path, EXPR_RESULT_FILE), "r") as f:
|
||||
json_list = [json.loads(line) for line in f if line]
|
||||
df = pd.json_normalize(json_list, sep="/")
|
||||
elif self._file_type == "csv":
|
||||
df = pd.read_csv(os.path.join(path, EXPR_PROGRESS_FILE))
|
||||
self.trial_dataframes[path] = df
|
||||
except Exception:
|
||||
fail_count += 1
|
||||
|
||||
|
@ -325,7 +358,10 @@ class Analysis:
|
|||
def _get_trial_paths(self) -> List[str]:
|
||||
_trial_paths = []
|
||||
for trial_path, _, files in os.walk(self._experiment_dir):
|
||||
if EXPR_PROGRESS_FILE in files:
|
||||
if (self._file_type == "json"
|
||||
and EXPR_RESULT_FILE in files) \
|
||||
or (self._file_type == "csv"
|
||||
and EXPR_PROGRESS_FILE in files):
|
||||
_trial_paths += [trial_path]
|
||||
|
||||
if not _trial_paths:
|
||||
|
|
|
@ -69,6 +69,21 @@ class ExperimentAnalysisSuite(unittest.TestCase):
|
|||
self.assertTrue(isinstance(df, pd.DataFrame))
|
||||
self.assertEquals(df.shape[0], self.num_samples)
|
||||
|
||||
def testLoadJson(self):
    """The results file type can be switched between csv and json."""
    frames_from_csv = self.ea.fetch_trial_dataframes()

    # Re-reading via json must discover the same set of trial dirs.
    self.ea.set_filetype("json")
    frames_from_json = self.ea.fetch_trial_dataframes()

    assert set(frames_from_csv) == set(frames_from_json)

    # Unsupported formats are rejected up front.
    with self.assertRaises(ValueError):
        self.ea.set_filetype("bad")

    # Switching back to csv restores the original view.
    self.ea.set_filetype("csv")
    frames_from_csv_again = self.ea.fetch_trial_dataframes()
    assert set(frames_from_csv) == set(frames_from_csv_again)
|
||||
|
||||
def testStats(self):
|
||||
assert self.ea.stats()
|
||||
assert self.ea.runner_data()
|
||||
|
|
Loading…
Add table
Reference in a new issue