Mirror of https://github.com/vale981/ray, synced 2025-03-06 18:41:40 -05:00.
[tune] allow to read trial results from json files in Analysis (#15915)

Co-authored-by: Richard Liaw <rliaw@berkeley.edu>

Parent: cb878b6514
Commit: e5b50fcc9d

2 changed files with 56 additions and 5 deletions
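Based on the diff below, a minimal usage sketch of the new option. The experiment directory and metric name are placeholders, and the ray.tune-level export of Analysis is assumed; this is an illustration, not the authoritative API documentation.

from ray.tune import Analysis

# Placeholder experiment directory and metric, for illustration only.
analysis = Analysis(
    "~/ray_results/my_experiment",
    default_metric="mean_loss",
    default_mode="min",
    file_type="json",  # read each trial's result.json instead of progress.csv
)
dfs = analysis.fetch_trial_dataframes()  # dict: trial dir -> DataFrame

# The file type can also be switched after construction; this re-reads the files.
analysis.set_filetype("csv")
dfs_csv = analysis.fetch_trial_dataframes()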
@@ -17,13 +17,15 @@ except ImportError:
 from ray.tune.error import TuneError
 from ray.tune.result import DEFAULT_METRIC, EXPR_PROGRESS_FILE, \
-    EXPR_PARAM_FILE, CONFIG_PREFIX, TRAINING_ITERATION
+    EXPR_RESULT_FILE, EXPR_PARAM_FILE, CONFIG_PREFIX, TRAINING_ITERATION
 from ray.tune.trial import Trial
 from ray.tune.utils.trainable import TrainableUtil
 from ray.tune.utils.util import unflattened_lookup
 
 logger = logging.getLogger(__name__)
 
+DEFAULT_FILE_TYPE = "csv"
+
 
 class Analysis:
     """Analyze all results from a directory of experiments.
@@ -39,12 +41,15 @@ class Analysis:
         default_mode (str): Default mode for comparing results. Has to be one
             of [min, max]. Can be overwritten with the ``mode`` parameter
             in the respective functions.
+        file_type (str): Read results from json or csv files. Has to be one
+            of [None, json, csv]. Defaults to csv.
     """
 
     def __init__(self,
                  experiment_dir: str,
                  default_metric: Optional[str] = None,
-                 default_mode: Optional[str] = None):
+                 default_mode: Optional[str] = None,
+                 file_type: Optional[str] = None):
         experiment_dir = os.path.expanduser(experiment_dir)
         if not os.path.isdir(experiment_dir):
             raise ValueError(
@@ -58,6 +63,7 @@ class Analysis:
             raise ValueError(
                 "`default_mode` has to be None or one of [min, max]")
         self.default_mode = default_mode
+        self._file_type = self._validate_filetype(file_type)
 
         if self.default_metric is None and self.default_mode:
             # If only a mode was passed, use anonymous metric
@@ -70,6 +76,23 @@ class Analysis:
         else:
             self.fetch_trial_dataframes()
 
+    def _validate_filetype(self, file_type: Optional[str] = None):
+        if file_type not in {None, "json", "csv"}:
+            raise ValueError(
+                "`file_type` has to be None or one of [json, csv].")
+        return file_type or DEFAULT_FILE_TYPE
+
+    def set_filetype(self, file_type: Optional[str] = None):
+        """Overrides the existing file type.
+
+        Args:
+            file_type (str): Read results from json or csv files. Has to be one
+                of [None, json, csv]. Defaults to csv.
+        """
+        self._file_type = self._validate_filetype(file_type)
+        self.fetch_trial_dataframes()
+        return True
+
     def _validate_metric(self, metric: str) -> str:
         if not metric and not self.default_metric:
             raise ValueError(
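A standalone sketch of the validation behavior the new helper introduces, mirroring the hunk above without importing Ray (names here are illustrative, not Ray's): None falls back to the csv default, and anything other than json or csv is rejected.

from typing import Optional

DEFAULT_FILE_TYPE = "csv"


def validate_filetype(file_type: Optional[str] = None) -> str:
    # Mirrors the check added above: only None, "json", or "csv" pass.
    if file_type not in {None, "json", "csv"}:
        raise ValueError("`file_type` has to be None or one of [json, csv].")
    return file_type or DEFAULT_FILE_TYPE


assert validate_filetype() == "csv"        # None falls back to the default
assert validate_filetype("json") == "json"
try:
    validate_filetype("parquet")           # anything else raises
except ValueError as err:
    print(err)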
@@ -169,11 +192,21 @@ class Analysis:
             return None
 
     def fetch_trial_dataframes(self) -> Dict[str, DataFrame]:
+        """Fetches trial dataframes from files.
+
+        Returns:
+            A dictionary containing "trial dir" to Dataframe.
+        """
         fail_count = 0
         for path in self._get_trial_paths():
             try:
-                self.trial_dataframes[path] = pd.read_csv(
-                    os.path.join(path, EXPR_PROGRESS_FILE))
+                if self._file_type == "json":
+                    with open(os.path.join(path, EXPR_RESULT_FILE), "r") as f:
+                        json_list = [json.loads(line) for line in f if line]
+                    df = pd.json_normalize(json_list, sep="/")
+                elif self._file_type == "csv":
+                    df = pd.read_csv(os.path.join(path, EXPR_PROGRESS_FILE))
+                self.trial_dataframes[path] = df
             except Exception:
                 fail_count += 1
 
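The json branch above reads result.json as one JSON object per line and flattens nested keys with a "/" separator, presumably to match the flattened column names of the csv files. A self-contained illustration with two fabricated result lines (not taken from the source):

import json

import pandas as pd

# Two fabricated result.json lines, one JSON object per line.
lines = [
    '{"training_iteration": 1, "config": {"lr": 0.01}, "mean_loss": 0.9}',
    '{"training_iteration": 2, "config": {"lr": 0.01}, "mean_loss": 0.7}',
]
json_list = [json.loads(line) for line in lines if line]
df = pd.json_normalize(json_list, sep="/")
print(sorted(df.columns))  # ['config/lr', 'mean_loss', 'training_iteration']
print(len(df))             # 2, one row per result line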
@@ -325,7 +358,10 @@ class Analysis:
     def _get_trial_paths(self) -> List[str]:
         _trial_paths = []
         for trial_path, _, files in os.walk(self._experiment_dir):
-            if EXPR_PROGRESS_FILE in files:
+            if (self._file_type == "json"
+                    and EXPR_RESULT_FILE in files) \
+                    or (self._file_type == "csv"
+                        and EXPR_PROGRESS_FILE in files):
                 _trial_paths += [trial_path]
 
         if not _trial_paths:
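The hunk above changes trial discovery so a directory only counts as a trial if it contains the results file matching the selected type (EXPR_RESULT_FILE or EXPR_PROGRESS_FILE, i.e. result.json or progress.csv in Tune). A small sketch of that selection logic in isolation, with fabricated trial directories and hard-coded file names standing in for the constants:

import os
import tempfile

# Fabricated layout: one trial wrote only progress.csv, another only result.json.
root = tempfile.mkdtemp()
for trial, filename in [("trial_a", "progress.csv"), ("trial_b", "result.json")]:
    os.makedirs(os.path.join(root, trial))
    open(os.path.join(root, trial, filename), "w").close()


def trial_paths(experiment_dir, file_type):
    # Mirrors the selection in the diff: keep a directory only if it contains
    # the results file that matches the chosen file type.
    wanted = "result.json" if file_type == "json" else "progress.csv"
    return [path for path, _, files in os.walk(experiment_dir) if wanted in files]


print(trial_paths(root, "csv"))   # [.../trial_a]
print(trial_paths(root, "json"))  # [.../trial_b]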
@@ -69,6 +69,21 @@ class ExperimentAnalysisSuite(unittest.TestCase):
         self.assertTrue(isinstance(df, pd.DataFrame))
         self.assertEquals(df.shape[0], self.num_samples)
 
+    def testLoadJson(self):
+        all_dataframes_via_csv = self.ea.fetch_trial_dataframes()
+
+        self.ea.set_filetype("json")
+        all_dataframes_via_json = self.ea.fetch_trial_dataframes()
+
+        assert set(all_dataframes_via_csv) == set(all_dataframes_via_json)
+
+        with self.assertRaises(ValueError):
+            self.ea.set_filetype("bad")
+
+        self.ea.set_filetype("csv")
+        all_dataframes_via_csv2 = self.ea.fetch_trial_dataframes()
+        assert set(all_dataframes_via_csv) == set(all_dataframes_via_csv2)
+
     def testStats(self):
         assert self.ea.stats()
         assert self.ea.runner_data()