[tune] allow reading trial results from json files in Analysis (#15915)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>
Author: lanlin
Date: 2021-06-21 11:41:48 +08:00 (committed by GitHub)
parent cb878b6514
commit e5b50fcc9d
2 changed files with 56 additions and 5 deletions

python/ray/tune/analysis/experiment_analysis.py

@@ -17,13 +17,15 @@ except ImportError:
 
 from ray.tune.error import TuneError
 from ray.tune.result import DEFAULT_METRIC, EXPR_PROGRESS_FILE, \
-    EXPR_PARAM_FILE, CONFIG_PREFIX, TRAINING_ITERATION
+    EXPR_RESULT_FILE, EXPR_PARAM_FILE, CONFIG_PREFIX, TRAINING_ITERATION
 from ray.tune.trial import Trial
 from ray.tune.utils.trainable import TrainableUtil
 from ray.tune.utils.util import unflattened_lookup
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_FILE_TYPE = "csv"
+
 
 class Analysis:
     """Analyze all results from a directory of experiments.
@@ -39,12 +41,15 @@ class Analysis:
         default_mode (str): Default mode for comparing results. Has to be one
             of [min, max]. Can be overwritten with the ``mode`` parameter
             in the respective functions.
+        file_type (str): Read results from json or csv files. Has to be one
+            of [None, json, csv]. Defaults to csv.
     """
 
     def __init__(self,
                  experiment_dir: str,
                  default_metric: Optional[str] = None,
-                 default_mode: Optional[str] = None):
+                 default_mode: Optional[str] = None,
+                 file_type: Optional[str] = None):
         experiment_dir = os.path.expanduser(experiment_dir)
         if not os.path.isdir(experiment_dir):
             raise ValueError(
@@ -58,6 +63,7 @@ class Analysis:
             raise ValueError(
                 "`default_mode` has to be None or one of [min, max]")
         self.default_mode = default_mode
+        self._file_type = self._validate_filetype(file_type)
 
         if self.default_metric is None and self.default_mode:
             # If only a mode was passed, use anonymous metric
@@ -70,6 +76,23 @@ class Analysis:
         else:
             self.fetch_trial_dataframes()
 
+    def _validate_filetype(self, file_type: Optional[str] = None):
+        if file_type not in {None, "json", "csv"}:
+            raise ValueError(
+                "`file_type` has to be None or one of [json, csv].")
+        return file_type or DEFAULT_FILE_TYPE
+
+    def set_filetype(self, file_type: Optional[str] = None):
+        """Overrides the existing file type.
+
+        Args:
+            file_type (str): Read results from json or csv files. Has to be
+                one of [None, json, csv]. Defaults to csv.
+        """
+        self._file_type = self._validate_filetype(file_type)
+        self.fetch_trial_dataframes()
+        return True
+
     def _validate_metric(self, metric: str) -> str:
         if not metric and not self.default_metric:
             raise ValueError(
@@ -169,11 +192,21 @@ class Analysis:
         return None
 
     def fetch_trial_dataframes(self) -> Dict[str, DataFrame]:
+        """Fetches trial dataframes from files.
+
+        Returns:
+            A dictionary mapping trial directories to DataFrames.
+        """
         fail_count = 0
         for path in self._get_trial_paths():
             try:
-                self.trial_dataframes[path] = pd.read_csv(
-                    os.path.join(path, EXPR_PROGRESS_FILE))
+                if self._file_type == "json":
+                    with open(os.path.join(path, EXPR_RESULT_FILE), "r") as f:
+                        json_list = [json.loads(line) for line in f if line]
+                    df = pd.json_normalize(json_list, sep="/")
+                elif self._file_type == "csv":
+                    df = pd.read_csv(os.path.join(path, EXPR_PROGRESS_FILE))
+                self.trial_dataframes[path] = df
             except Exception:
                 fail_count += 1
@@ -325,7 +358,10 @@ class Analysis:
     def _get_trial_paths(self) -> List[str]:
         _trial_paths = []
         for trial_path, _, files in os.walk(self._experiment_dir):
-            if EXPR_PROGRESS_FILE in files:
+            if (self._file_type == "json"
+                    and EXPR_RESULT_FILE in files) \
+                    or (self._file_type == "csv"
+                        and EXPR_PROGRESS_FILE in files):
                 _trial_paths += [trial_path]
         if not _trial_paths:
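
For orientation, here is a minimal usage sketch of the API this commit adds. The experiment directory is a hypothetical path produced by an earlier tune.run() call; only the file_type parameter and set_filetype() are new.

    # Sketch of the new file_type option on Analysis (hypothetical paths).
    from ray.tune import Analysis

    # Read each trial's result.json (newline-delimited JSON) instead of
    # progress.csv. Omitting file_type keeps the default, "csv".
    analysis = Analysis("~/ray_results/my_experiment", file_type="json")
    dfs = analysis.fetch_trial_dataframes()  # maps trial dir -> DataFrame

    # set_filetype() re-validates the type and re-reads the dataframes;
    # anything other than None, "json", or "csv" raises ValueError.
    analysis.set_filetype("csv")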

python/ray/tune/tests/test_experiment_analysis.py

@@ -69,6 +69,21 @@ class ExperimentAnalysisSuite(unittest.TestCase):
         self.assertTrue(isinstance(df, pd.DataFrame))
         self.assertEquals(df.shape[0], self.num_samples)
 
+    def testLoadJson(self):
+        all_dataframes_via_csv = self.ea.fetch_trial_dataframes()
+
+        self.ea.set_filetype("json")
+        all_dataframes_via_json = self.ea.fetch_trial_dataframes()
+        assert set(all_dataframes_via_csv) == set(all_dataframes_via_json)
+
+        with self.assertRaises(ValueError):
+            self.ea.set_filetype("bad")
+
+        self.ea.set_filetype("csv")
+        all_dataframes_via_csv2 = self.ea.fetch_trial_dataframes()
+        assert set(all_dataframes_via_csv) == set(all_dataframes_via_csv2)
+
     def testStats(self):
         assert self.ea.stats()
         assert self.ea.runner_data()
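
A side note on the json branch above: result.json is newline-delimited, one JSON object per reported result, and pd.json_normalize with sep="/" flattens nested keys into the same "config/lr"-style column names that Tune's CSV output uses. A self-contained sketch of that parsing, with made-up records:

    import json
    import pandas as pd

    # Two made-up result.json lines (one JSON object per training iteration).
    lines = [
        '{"training_iteration": 1, "mean_loss": 0.9, "config": {"lr": 0.01}}',
        '{"training_iteration": 2, "mean_loss": 0.7, "config": {"lr": 0.01}}',
    ]
    records = [json.loads(line) for line in lines if line]

    # sep="/" turns the nested config dict into a "config/lr" column.
    df = pd.json_normalize(records, sep="/")
    print(list(df.columns))  # ['training_iteration', 'mean_loss', 'config/lr']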