From 714193ce6fe1a6b6a6c6ae6c936e8effaff372e7 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Sat, 28 Aug 2021 02:50:26 +0200 Subject: [PATCH] [SGDv2] Tensorboard Callback (#17824) * [SGD] save checkpoints to disk * fix test; add logs * Extend SGDv2 callback API * Move json file creation to JsonLoggerCallback * TBXLoggerCallback * Simplify, fix linear example * rename log_dir to logdir for consistency with tune * Add test * Fix * Break up logging classes * Fix error * Update type hint for results * Refactor Co-authored-by: Matthew Deng --- python/ray/tune/utils/util.py | 211 +---------------- python/ray/util/ml_utils/dict.py | 216 ++++++++++++++++++ python/ray/util/sgd/v2/backends/backend.py | 3 +- python/ray/util/sgd/v2/callbacks/__init__.py | 5 +- python/ray/util/sgd/v2/callbacks/callback.py | 29 ++- python/ray/util/sgd/v2/callbacks/logging.py | 212 ++++++++++++++--- python/ray/util/sgd/v2/constants.py | 2 +- .../ray/util/sgd/v2/examples/train_linear.py | 11 +- .../ray/util/sgd/v2/tests/test_callbacks.py | 54 ++++- python/ray/util/sgd/v2/tests/test_trainer.py | 1 - python/ray/util/sgd/v2/trainer.py | 2 +- 11 files changed, 488 insertions(+), 258 deletions(-) create mode 100644 python/ray/util/ml_utils/dict.py diff --git a/python/ray/tune/utils/util.py b/python/ray/tune/utils/util.py index 9612f1cc2..0f4612c66 100644 --- a/python/ray/tune/utils/util.py +++ b/python/ray/tune/utils/util.py @@ -9,8 +9,7 @@ import inspect import threading import time import uuid -from collections import defaultdict, deque -from collections.abc import Mapping, Sequence +from collections import defaultdict from datetime import datetime from threading import Thread from typing import Optional @@ -20,6 +19,9 @@ import ray import psutil from ray.util.ml_utils.json import SafeFallbackEncoder # noqa +from ray.util.ml_utils.dict import merge_dicts, deep_update, flatten_dict, \ + unflatten_dict, unflatten_list_dict, \ + unflattened_lookup # noqa logger = logging.getLogger(__name__) @@ -220,211 +222,6 @@ def is_nan_or_inf(value): return np.isnan(value) or np.isinf(value) -def merge_dicts(d1, d2): - """ - Args: - d1 (dict): Dict 1. - d2 (dict): Dict 2. - - Returns: - dict: A new dict that is d1 and d2 deep merged. - """ - merged = copy.deepcopy(d1) - deep_update(merged, d2, True, []) - return merged - - -def deep_update(original, - new_dict, - new_keys_allowed=False, - allow_new_subkey_list=None, - override_all_if_type_changes=None): - """Updates original dict with values from new_dict recursively. - - If new key is introduced in new_dict, then if new_keys_allowed is not - True, an error will be thrown. Further, for sub-dicts, if the key is - in the allow_new_subkey_list, then new subkeys can be introduced. - - Args: - original (dict): Dictionary with default values. - new_dict (dict): Dictionary with values to be updated - new_keys_allowed (bool): Whether new keys are allowed. - allow_new_subkey_list (Optional[List[str]]): List of keys that - correspond to dict values where new subkeys can be introduced. - This is only at the top level. - override_all_if_type_changes(Optional[List[str]]): List of top level - keys with value=dict, for which we always simply override the - entire value (dict), iff the "type" key in that value dict changes. 
- """ - allow_new_subkey_list = allow_new_subkey_list or [] - override_all_if_type_changes = override_all_if_type_changes or [] - - for k, value in new_dict.items(): - if k not in original and not new_keys_allowed: - raise Exception("Unknown config parameter `{}` ".format(k)) - - # Both orginal value and new one are dicts. - if isinstance(original.get(k), dict) and isinstance(value, dict): - # Check old type vs old one. If different, override entire value. - if k in override_all_if_type_changes and \ - "type" in value and "type" in original[k] and \ - value["type"] != original[k]["type"]: - original[k] = value - # Allowed key -> ok to add new subkeys. - elif k in allow_new_subkey_list: - deep_update(original[k], value, True) - # Non-allowed key. - else: - deep_update(original[k], value, new_keys_allowed) - # Original value not a dict OR new value not a dict: - # Override entire value. - else: - original[k] = value - return original - - -def flatten_dict(dt: Dict, - delimiter: str = "/", - prevent_delimiter: bool = False, - flatten_list: bool = False): - """Flatten dict. - - Output and input are of the same dict type. - Input dict remains the same after the operation. - """ - - def _raise_delimiter_exception(): - raise ValueError( - f"Found delimiter `{delimiter}` in key when trying to flatten " - f"array. Please avoid using the delimiter in your specification.") - - dt = copy.copy(dt) - if prevent_delimiter and any(delimiter in key for key in dt): - # Raise if delimiter is any of the keys - _raise_delimiter_exception() - - while_check = (dict, list) if flatten_list else dict - - while any(isinstance(v, while_check) for v in dt.values()): - remove = [] - add = {} - for key, value in dt.items(): - if isinstance(value, dict): - for subkey, v in value.items(): - if prevent_delimiter and delimiter in subkey: - # Raise if delimiter is in any of the subkeys - _raise_delimiter_exception() - - add[delimiter.join([key, str(subkey)])] = v - remove.append(key) - elif flatten_list and isinstance(value, list): - for i, v in enumerate(value): - if prevent_delimiter and delimiter in subkey: - # Raise if delimiter is in any of the subkeys - _raise_delimiter_exception() - - add[delimiter.join([key, str(i)])] = v - remove.append(key) - - dt.update(add) - for k in remove: - del dt[k] - return dt - - -def unflatten_dict(dt, delimiter="/"): - """Unflatten dict. Does not support unflattening lists.""" - dict_type = type(dt) - out = dict_type() - for key, val in dt.items(): - path = key.split(delimiter) - item = out - for k in path[:-1]: - item = item.setdefault(k, dict_type()) - if not isinstance(item, dict_type): - raise TypeError( - f"Cannot unflatten dict due the key '{key}' " - f"having a parent key '{k}', which value is not " - f"of type {dict_type} (got {type(item)}). " - "Change the key names to resolve the conflict.") - item[path[-1]] = val - return out - - -def unflatten_list_dict(dt, delimiter="/"): - """Unflatten nested dict and list. - - This function now has some limitations: - (1) The keys of dt must be str. - (2) If unflattened dt (the result) contains list, the index order must be - ascending when accessing dt. Otherwise, this function will throw - AssertionError. - (3) The unflattened dt (the result) shouldn't contain dict with number - keys. - - Be careful to use this function. If you want to improve this function, - please also improve the unit test. See #14487 for more details. - - Args: - dt (dict): Flattened dictionary that is originally nested by multiple - list and dict. 
- delimiter (str): Delimiter of keys. - - Example: - >>> dt = {"aaa/0/bb": 12, "aaa/1/cc": 56, "aaa/1/dd": 92} - >>> unflatten_list_dict(dt) - {'aaa': [{'bb': 12}, {'cc': 56, 'dd': 92}]} - """ - out_type = list if list(dt)[0].split(delimiter, 1)[0].isdigit() \ - else type(dt) - out = out_type() - for key, val in dt.items(): - path = key.split(delimiter) - - item = out - for i, k in enumerate(path[:-1]): - next_type = list if path[i + 1].isdigit() else dict - if isinstance(item, dict): - item = item.setdefault(k, next_type()) - elif isinstance(item, list): - if int(k) >= len(item): - item.append(next_type()) - assert int(k) == len(item) - 1 - item = item[int(k)] - - if isinstance(item, dict): - item[path[-1]] = val - elif isinstance(item, list): - item.append(val) - assert int(path[-1]) == len(item) - 1 - return out - - -def unflattened_lookup(flat_key, lookup, delimiter="/", **kwargs): - """ - Unflatten `flat_key` and iteratively look up in `lookup`. E.g. - `flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`. - """ - if flat_key in lookup: - return lookup[flat_key] - keys = deque(flat_key.split(delimiter)) - base = lookup - while keys: - key = keys.popleft() - try: - if isinstance(base, Mapping): - base = base[key] - elif isinstance(base, Sequence): - base = base[int(key)] - else: - raise KeyError() - except KeyError as e: - if "default" in kwargs: - return kwargs["default"] - raise e - return base - - def _to_pinnable(obj): """Converts obj to a form that can be pinned in object store memory. diff --git a/python/ray/util/ml_utils/dict.py b/python/ray/util/ml_utils/dict.py new file mode 100644 index 000000000..b9b23c0c5 --- /dev/null +++ b/python/ray/util/ml_utils/dict.py @@ -0,0 +1,216 @@ +from typing import Dict, List, Union, Optional, TypeVar +import copy +from collections import deque +from collections.abc import Mapping, Sequence + +T = TypeVar("T") + + +def merge_dicts(d1: dict, d2: dict) -> dict: + """ + Args: + d1 (dict): Dict 1. + d2 (dict): Dict 2. + + Returns: + dict: A new dict that is d1 and d2 deep merged. + """ + merged = copy.deepcopy(d1) + deep_update(merged, d2, True, []) + return merged + + +def deep_update( + original: dict, + new_dict: dict, + new_keys_allowed: str = False, + allow_new_subkey_list: Optional[List[str]] = None, + override_all_if_type_changes: Optional[List[str]] = None) -> dict: + """Updates original dict with values from new_dict recursively. + + If new key is introduced in new_dict, then if new_keys_allowed is not + True, an error will be thrown. Further, for sub-dicts, if the key is + in the allow_new_subkey_list, then new subkeys can be introduced. + + Args: + original (dict): Dictionary with default values. + new_dict (dict): Dictionary with values to be updated + new_keys_allowed (bool): Whether new keys are allowed. + allow_new_subkey_list (Optional[List[str]]): List of keys that + correspond to dict values where new subkeys can be introduced. + This is only at the top level. + override_all_if_type_changes(Optional[List[str]]): List of top level + keys with value=dict, for which we always simply override the + entire value (dict), iff the "type" key in that value dict changes. + """ + allow_new_subkey_list = allow_new_subkey_list or [] + override_all_if_type_changes = override_all_if_type_changes or [] + + for k, value in new_dict.items(): + if k not in original and not new_keys_allowed: + raise Exception("Unknown config parameter `{}` ".format(k)) + + # Both orginal value and new one are dicts. 
+ if isinstance(original.get(k), dict) and isinstance(value, dict): + # Check old type vs old one. If different, override entire value. + if k in override_all_if_type_changes and \ + "type" in value and "type" in original[k] and \ + value["type"] != original[k]["type"]: + original[k] = value + # Allowed key -> ok to add new subkeys. + elif k in allow_new_subkey_list: + deep_update(original[k], value, True) + # Non-allowed key. + else: + deep_update(original[k], value, new_keys_allowed) + # Original value not a dict OR new value not a dict: + # Override entire value. + else: + original[k] = value + return original + + +def flatten_dict(dt: Dict, + delimiter: str = "/", + prevent_delimiter: bool = False, + flatten_list: bool = False): + """Flatten dict. + + Output and input are of the same dict type. + Input dict remains the same after the operation. + """ + + def _raise_delimiter_exception(): + raise ValueError( + f"Found delimiter `{delimiter}` in key when trying to flatten " + f"array. Please avoid using the delimiter in your specification.") + + dt = copy.copy(dt) + if prevent_delimiter and any(delimiter in key for key in dt): + # Raise if delimiter is any of the keys + _raise_delimiter_exception() + + while_check = (dict, list) if flatten_list else dict + + while any(isinstance(v, while_check) for v in dt.values()): + remove = [] + add = {} + for key, value in dt.items(): + if isinstance(value, dict): + for subkey, v in value.items(): + if prevent_delimiter and delimiter in subkey: + # Raise if delimiter is in any of the subkeys + _raise_delimiter_exception() + + add[delimiter.join([key, str(subkey)])] = v + remove.append(key) + elif flatten_list and isinstance(value, list): + for i, v in enumerate(value): + if prevent_delimiter and delimiter in subkey: + # Raise if delimiter is in any of the subkeys + _raise_delimiter_exception() + + add[delimiter.join([key, str(i)])] = v + remove.append(key) + + dt.update(add) + for k in remove: + del dt[k] + return dt + + +def unflatten_dict(dt: Dict[str, T], delimiter: str = "/") -> Dict[str, T]: + """Unflatten dict. Does not support unflattening lists.""" + dict_type = type(dt) + out = dict_type() + for key, val in dt.items(): + path = key.split(delimiter) + item = out + for k in path[:-1]: + item = item.setdefault(k, dict_type()) + if not isinstance(item, dict_type): + raise TypeError( + f"Cannot unflatten dict due the key '{key}' " + f"having a parent key '{k}', which value is not " + f"of type {dict_type} (got {type(item)}). " + "Change the key names to resolve the conflict.") + item[path[-1]] = val + return out + + +def unflatten_list_dict(dt: Dict[str, T], + delimiter: str = "/") -> Dict[str, T]: + """Unflatten nested dict and list. + + This function now has some limitations: + (1) The keys of dt must be str. + (2) If unflattened dt (the result) contains list, the index order must be + ascending when accessing dt. Otherwise, this function will throw + AssertionError. + (3) The unflattened dt (the result) shouldn't contain dict with number + keys. + + Be careful to use this function. If you want to improve this function, + please also improve the unit test. See #14487 for more details. + + Args: + dt (dict): Flattened dictionary that is originally nested by multiple + list and dict. + delimiter (str): Delimiter of keys. 
+ + Example: + >>> dt = {"aaa/0/bb": 12, "aaa/1/cc": 56, "aaa/1/dd": 92} + >>> unflatten_list_dict(dt) + {'aaa': [{'bb': 12}, {'cc': 56, 'dd': 92}]} + """ + out_type = list if list(dt)[0].split(delimiter, 1)[0].isdigit() \ + else type(dt) + out = out_type() + for key, val in dt.items(): + path = key.split(delimiter) + + item = out + for i, k in enumerate(path[:-1]): + next_type = list if path[i + 1].isdigit() else dict + if isinstance(item, dict): + item = item.setdefault(k, next_type()) + elif isinstance(item, list): + if int(k) >= len(item): + item.append(next_type()) + assert int(k) == len(item) - 1 + item = item[int(k)] + + if isinstance(item, dict): + item[path[-1]] = val + elif isinstance(item, list): + item.append(val) + assert int(path[-1]) == len(item) - 1 + return out + + +def unflattened_lookup(flat_key: str, + lookup: Union[Mapping, Sequence], + delimiter: str = "/", + **kwargs) -> Union[Mapping, Sequence]: + """ + Unflatten `flat_key` and iteratively look up in `lookup`. E.g. + `flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`. + """ + if flat_key in lookup: + return lookup[flat_key] + keys = deque(flat_key.split(delimiter)) + base = lookup + while keys: + key = keys.popleft() + try: + if isinstance(base, Mapping): + base = base[key] + elif isinstance(base, Sequence): + base = base[int(key)] + else: + raise KeyError() + except KeyError as e: + if "default" in kwargs: + return kwargs["default"] + raise e + return base diff --git a/python/ray/util/sgd/v2/backends/backend.py b/python/ray/util/sgd/v2/backends/backend.py index b2605df20..4b832a808 100644 --- a/python/ray/util/sgd/v2/backends/backend.py +++ b/python/ray/util/sgd/v2/backends/backend.py @@ -50,7 +50,8 @@ class BackendExecutor: generated. Attributes: - logdir (Path): Path to the file directory where logs will be persisted. + logdir (Path): Path to the file directory where logs will be + persisted. latest_run_dir (Optional[Path]): Path to the file directory for the latest run. Configured through ``start_training``. latest_checkpoint_dir (Optional[Path]): Path to the file directory for diff --git a/python/ray/util/sgd/v2/callbacks/__init__.py b/python/ray/util/sgd/v2/callbacks/__init__.py index abea77f4d..9ebf4fe92 100644 --- a/python/ray/util/sgd/v2/callbacks/__init__.py +++ b/python/ray/util/sgd/v2/callbacks/__init__.py @@ -1,4 +1,5 @@ from ray.util.sgd.v2.callbacks.callback import SGDCallback -from ray.util.sgd.v2.callbacks.logging import JsonLoggerCallback +from ray.util.sgd.v2.callbacks.logging import (JsonLoggerCallback, + TBXLoggerCallback) -__all__ = ["SGDCallback", "JsonLoggerCallback"] +__all__ = ["SGDCallback", "JsonLoggerCallback", "TBXLoggerCallback"] diff --git a/python/ray/util/sgd/v2/callbacks/callback.py b/python/ray/util/sgd/v2/callbacks/callback.py index a2e6704a7..3d8e00e53 100644 --- a/python/ray/util/sgd/v2/callbacks/callback.py +++ b/python/ray/util/sgd/v2/callbacks/callback.py @@ -1,21 +1,36 @@ import abc -from typing import List, Optional, Dict +from typing import List, Dict class SGDCallback(metaclass=abc.ABCMeta): """Abstract SGD callback class.""" - def handle_result(self, results: Optional[List[Dict]]): - """Called every time sgd.report() is called.""" + def handle_result(self, results: List[Dict], **info): + """Called every time sgd.report() is called. + + Args: + results (List[Dict]): List of results from the training + function. Each value in the list corresponds to the output of + the training function from each worker. + **info: kwargs dict for forward compatibility. 
+ """ pass - def start_training(self): - """Called once on training start.""" + def start_training(self, logdir: str, **info): + """Called once on training start. + + Args: + logdir (str): Path to the file directory where logs + should be persisted. + **info: kwargs dict for forward compatibility. + """ pass - def finish_training(self, error: bool = False): + def finish_training(self, error: bool = False, **info): """Called once after training is over. Args: - error (bool): If True, there was an exception during training.""" + error (bool): If True, there was an exception during training. + **info: kwargs dict for forward compatibility. + """ pass diff --git a/python/ray/util/sgd/v2/callbacks/logging.py b/python/ray/util/sgd/v2/callbacks/logging.py index 8df3c6be8..a317e3770 100644 --- a/python/ray/util/sgd/v2/callbacks/logging.py +++ b/python/ray/util/sgd/v2/callbacks/logging.py @@ -1,22 +1,52 @@ -from typing import Iterable, List, Optional, Dict, Union +from typing import Iterable, List, Optional, Dict, Set, Tuple, Union import abc - +import warnings +import logging +import numpy as np import json from pathlib import Path +from ray.util.debug import log_once +from ray.util.ml_utils.dict import flatten_dict from ray.util.ml_utils.json import SafeFallbackEncoder from ray.util.sgd.v2.callbacks import SGDCallback -from ray.util.sgd.v2.constants import RESULT_FILE_JSON +from ray.util.sgd.v2.constants import (RESULT_FILE_JSON, TRAINING_ITERATION, + TIME_TOTAL_S, TIMESTAMP, PID) + +logger = logging.getLogger(__name__) -class SGDSingleFileLoggingCallback(SGDCallback, metaclass=abc.ABCMeta): +class SGDLogdirMixin: + def start_training(self, logdir: str, **info): + if self._logdir: + logdir_path = Path(self._logdir) + else: + logdir_path = Path(logdir) + + if not logdir_path.is_dir(): + raise ValueError(f"logdir '{logdir}' must be a directory.") + + self._logdir_path = logdir_path + + @property + def logdir(self) -> Optional[Path]: + """Path to currently used logging directory.""" + if not hasattr(self, "_logdir_path"): + return Path(self._logdir) + return Path(self._logdir_path) + + +class SGDSingleFileLoggingCallback( + SGDLogdirMixin, SGDCallback, metaclass=abc.ABCMeta): """Abstract SGD logging callback class. Args: - logdir (str): Path to directory where the results file should be. - filename (str|None): Filename in logdir to save results to. + logdir (Optional[str]): Path to directory where the results file + should be. If None, will be set by the Trainer. + filename (Optional[str]): Filename in logdir to save results to. workers_to_log (int|List[int]|None): Worker indices to log. - If None, will log all workers. + If None, will log all workers. By default, will log the + worker with index 0. 
""" # Defining it like this ensures it will be overwritten @@ -24,18 +54,15 @@ class SGDSingleFileLoggingCallback(SGDCallback, metaclass=abc.ABCMeta): _default_filename: Union[str, Path] def __init__(self, - logdir: str, + logdir: Optional[str] = None, filename: Optional[str] = None, workers_to_log: Optional[Union[int, List[int]]] = 0) -> None: - logdir_path = Path(logdir) + self._logdir = logdir + self._filename = filename + self._workers_to_log = self._validate_workers_to_log(workers_to_log) + self._log_path = None - if not logdir_path.is_dir(): - raise ValueError(f"logdir '{logdir}' must be a directory.") - - if filename is None: - filename = self._default_filename - - self._log_path = logdir_path.joinpath(Path(filename)) + def _validate_workers_to_log(self, workers_to_log) -> List[int]: if isinstance(workers_to_log, int): workers_to_log = [workers_to_log] @@ -46,35 +73,58 @@ class SGDSingleFileLoggingCallback(SGDCallback, metaclass=abc.ABCMeta): if not all(isinstance(worker, int) for worker in workers_to_log): raise TypeError( "All elements of workers_to_log must be integers.") + if len(workers_to_log) < 1: + raise ValueError( + "At least one worker must be specified in workers_to_log.") + return workers_to_log - self._workers_to_log = workers_to_log + def _create_log_path(self, logdir_path: Path, filename: Path) -> Path: + if not filename: + raise ValueError("filename cannot be None or empty.") + return logdir_path.joinpath(Path(filename)) + + def start_training(self, logdir: str, **info): + super().start_training(logdir, **info) + + if not self._filename: + filename = self._default_filename + else: + filename = self._filename + + self._log_path = self._create_log_path(self.logdir, filename) @property - def log_path(self) -> Path: - """Path to the log file.""" - return self._log_path + def log_path(self) -> Optional[Path]: + """Path to the log file. - def start_training(self): - # Create a JSON file with an empty list - # that will be latter appended to - with open(self._log_path, "w") as f: - json.dump([], f, cls=SafeFallbackEncoder) + Will be None before `start_training` is called for the first time. + """ + return self._log_path class JsonLoggerCallback(SGDSingleFileLoggingCallback): """Logs SGD results in json format. Args: - logdir (str): Path to directory where the results file should be. - filename (str|None): Filename in logdir to save results to. - Defaults to "results.json". + logdir (Optional[str]): Path to directory where the results file + should be. If None, will be set by the Trainer. + filename (Optional[str]): Filename in logdir to save results to. workers_to_log (int|List[int]|None): Worker indices to log. - If None, will log all workers. + If None, will log all workers. By default, will log the + worker with index 0. 
""" _default_filename: Union[str, Path] = RESULT_FILE_JSON - def handle_result(self, results: Optional[List[Dict]]): + def start_training(self, logdir: str, **info): + super().start_training(logdir, **info) + + # Create a JSON file with an empty list + # that will be latter appended to + with open(self._log_path, "w") as f: + json.dump([], f, cls=SafeFallbackEncoder) + + def handle_result(self, results: List[Dict], **info): if self._workers_to_log is None or results is None: results_to_log = results else: @@ -87,3 +137,105 @@ class JsonLoggerCallback(SGDSingleFileLoggingCallback): f.seek(0) json.dump( loaded_results + [results_to_log], f, cls=SafeFallbackEncoder) + + +class SGDSingleWorkerLoggingCallback( + SGDLogdirMixin, SGDCallback, metaclass=abc.ABCMeta): + """Abstract SGD logging callback class. + + Allows only for single-worker logging. + + Args: + logdir (Optional[str]): Path to directory where the results file + should be. If None, will be set by the Trainer. + worker_to_log (int): Worker index to log. By default, will log the + worker with index 0. + """ + + def __init__(self, logdir: Optional[str] = None, + worker_to_log: int = 0) -> None: + self._logdir = logdir + self._workers_to_log = self._validate_worker_to_log(worker_to_log) + self._log_path = None + + def _validate_worker_to_log(self, worker_to_log) -> int: + if isinstance(worker_to_log, Iterable): + worker_to_log = list(worker_to_log) + if len(worker_to_log) > 1: + raise ValueError( + f"{self.__class__.__name__} only supports logging " + "from a single worker.") + elif len(worker_to_log) < 1: + raise ValueError( + "At least one worker must be specified in workers_to_log.") + worker_to_log = worker_to_log[0] + if not isinstance(worker_to_log, int): + raise TypeError("workers_to_log must be an integer.") + return worker_to_log + + +class TBXLoggerCallback(SGDSingleWorkerLoggingCallback): + """Logs SGD results in TensorboardX format. + + Args: + logdir (Optional[str]): Path to directory where the results file + should be. If None, will be set by the Trainer. + worker_to_log (int): Worker index to log. By default, will log the + worker with index 0. 
+ """ + + VALID_SUMMARY_TYPES: Tuple[type] = (int, float, np.float32, np.float64, + np.int32, np.int64) + IGNORE_KEYS: Set[str] = {PID, TIMESTAMP, TIME_TOTAL_S, TRAINING_ITERATION} + + def start_training(self, logdir: str, **info): + super().start_training(logdir, **info) + + try: + from tensorboardX import SummaryWriter + except ImportError: + if log_once("tbx-install"): + warnings.warn( + "pip install 'tensorboardX' to see TensorBoard files.") + raise + + self._file_writer = SummaryWriter(self.logdir, flush_secs=30) + + def handle_result(self, results: List[Dict], **info): + result = results[self._workers_to_log] + step = result[TRAINING_ITERATION] + result = {k: v for k, v in result.items() if k not in self.IGNORE_KEYS} + flat_result = flatten_dict(result, delimiter="/") + path = ["ray", "sgd"] + + # same logic as in ray.tune.logger.TBXLogger + for attr, value in flat_result.items(): + full_attr = "/".join(path + [attr]) + if (isinstance(value, self.VALID_SUMMARY_TYPES) + and not np.isnan(value)): + self._file_writer.add_scalar( + full_attr, value, global_step=step) + elif ((isinstance(value, list) and len(value) > 0) + or (isinstance(value, np.ndarray) and value.size > 0)): + + # Must be video + if isinstance(value, np.ndarray) and value.ndim == 5: + self._file_writer.add_video( + full_attr, value, global_step=step, fps=20) + continue + + try: + self._file_writer.add_histogram( + full_attr, value, global_step=step) + # In case TensorboardX still doesn't think it's a valid value + # (e.g. `[[]]`), warn and move on. + except (ValueError, TypeError): + if log_once("invalid_tbx_value"): + warnings.warn( + "You are trying to log an invalid value ({}={}) " + "via {}!".format(full_attr, value, + type(self).__name__)) + self._file_writer.flush() + + def finish_training(self, error: bool = False, **info): + self._file_writer.close() diff --git a/python/ray/util/sgd/v2/constants.py b/python/ray/util/sgd/v2/constants.py index 1f31e8a4a..1517c3537 100644 --- a/python/ray/util/sgd/v2/constants.py +++ b/python/ray/util/sgd/v2/constants.py @@ -1,6 +1,6 @@ -# Autofilled sgd.report() metrics. Keys should be consistent with Tune. from pathlib import Path +# Autofilled sgd.report() metrics. Keys should be consistent with Tune. 
TIMESTAMP = "_timestamp" TIME_THIS_ITER_S = "_time_this_iter_s" TRAINING_ITERATION = "_training_iteration" diff --git a/python/ray/util/sgd/v2/examples/train_linear.py b/python/ray/util/sgd/v2/examples/train_linear.py index 937832065..de6f083bf 100644 --- a/python/ray/util/sgd/v2/examples/train_linear.py +++ b/python/ray/util/sgd/v2/examples/train_linear.py @@ -1,11 +1,11 @@ import argparse import numpy as np -import ray.util.sgd.v2 as sgd import torch import torch.nn as nn +import ray.util.sgd.v2 as sgd from ray.util.sgd.v2 import Trainer, TorchConfig -from ray.util.sgd.v2.callbacks import JsonLoggerCallback +from ray.util.sgd.v2.callbacks import JsonLoggerCallback, TBXLoggerCallback from torch.nn.parallel import DistributedDataParallel from torch.utils.data import DistributedSampler @@ -92,7 +92,10 @@ def train_linear(num_workers=1): config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": 3} trainer.start() results = trainer.run( - train_func, config, callbacks=[JsonLoggerCallback("./sgd_results")]) + train_func, + config, + callbacks=[JsonLoggerCallback(), + TBXLoggerCallback()]) trainer.shutdown() print(results) @@ -110,7 +113,7 @@ if __name__ == "__main__": "--num-workers", "-n", type=int, - default=1, + default=2, help="Sets number of workers for training.") parser.add_argument( "--smoke-test", diff --git a/python/ray/util/sgd/v2/tests/test_callbacks.py b/python/ray/util/sgd/v2/tests/test_callbacks.py index 9d733b66f..75e6d0cc3 100644 --- a/python/ray/util/sgd/v2/tests/test_callbacks.py +++ b/python/ray/util/sgd/v2/tests/test_callbacks.py @@ -3,6 +3,8 @@ import os import shutil import tempfile import json +import glob +from collections import defaultdict import ray import ray.util.sgd.v2 as sgd @@ -10,10 +12,16 @@ from ray.util.sgd.v2 import Trainer from ray.util.sgd.v2.constants import ( TRAINING_ITERATION, DETAILED_AUTOFILLED_KEYS, BASIC_AUTOFILLED_KEYS, ENABLE_DETAILED_AUTOFILLED_METRICS_ENV) -from ray.util.sgd.v2.callbacks import JsonLoggerCallback +from ray.util.sgd.v2.callbacks import JsonLoggerCallback, TBXLoggerCallback from ray.util.sgd.v2.backends.backend import BackendConfig, BackendInterface from ray.util.sgd.v2.worker_group import WorkerGroup +try: + from tensorflow.python.summary.summary_iterator \ + import summary_iterator +except ImportError: + summary_iterator = None + @pytest.fixture def ray_start_4_cpus(): @@ -78,15 +86,17 @@ def test_json(ray_start_4_cpus, make_temp_dir, workers_to_log, detailed, # if None, use default value callback = JsonLoggerCallback( make_temp_dir, workers_to_log=workers_to_log) - assert str( - callback.log_path.name) == JsonLoggerCallback._default_filename else: callback = JsonLoggerCallback( make_temp_dir, filename=filename, workers_to_log=workers_to_log) - assert str(callback.log_path.name) == filename trainer = Trainer(config, num_workers=num_workers) trainer.start() trainer.run(train_func, callbacks=[callback]) + if filename is None: + assert str( + callback.log_path.name) == JsonLoggerCallback._default_filename + else: + assert str(callback.log_path.name) == filename with open(callback.log_path, "r") as f: log = json.load(f) @@ -110,3 +120,39 @@ def test_json(ray_start_4_cpus, make_temp_dir, workers_to_log, detailed, assert all( all(not any(key in worker for key in DETAILED_AUTOFILLED_KEYS) for worker in element) for element in log) + + +def _validate_tbx_result(events_dir): + events_file = list(glob.glob(f"{events_dir}/events*"))[0] + results = defaultdict(list) + for event in summary_iterator(events_file): + for v in 
event.summary.value: + assert v.tag.startswith("ray/sgd") + results[v.tag[8:]].append(v.simple_value) + + assert len(results["episode_reward_mean"]) == 3 + assert [int(res) for res in results["episode_reward_mean"]] == [4, 5, 6] + assert len(results["score"]) == 1 + assert len(results["hello/world"]) == 1 + + +@pytest.mark.skipif( + summary_iterator is None, reason="tensorboard is not installed") +def test_TBX(ray_start_4_cpus, make_temp_dir): + config = TestConfig() + + temp_dir = make_temp_dir + num_workers = 4 + + def train_func(): + sgd.report(episode_reward_mean=4) + sgd.report(episode_reward_mean=5) + sgd.report(episode_reward_mean=6, score=[1, 2, 3], hello={"world": 1}) + return 1 + + callback = TBXLoggerCallback(temp_dir) + trainer = Trainer(config, num_workers=num_workers) + trainer.start() + trainer.run(train_func, callbacks=[callback]) + + _validate_tbx_result(temp_dir) diff --git a/python/ray/util/sgd/v2/tests/test_trainer.py b/python/ray/util/sgd/v2/tests/test_trainer.py index 768766618..ef589c386 100644 --- a/python/ray/util/sgd/v2/tests/test_trainer.py +++ b/python/ray/util/sgd/v2/tests/test_trainer.py @@ -7,7 +7,6 @@ import ray import ray.util.sgd.v2 as sgd import tensorflow as tf import torch - from ray.util.sgd.v2 import Trainer from ray.util.sgd.v2.backends.backend import BackendConfig, BackendInterface, \ BackendExecutor diff --git a/python/ray/util/sgd/v2/trainer.py b/python/ray/util/sgd/v2/trainer.py index 4f23146ad..b049c8113 100644 --- a/python/ray/util/sgd/v2/trainer.py +++ b/python/ray/util/sgd/v2/trainer.py @@ -158,7 +158,7 @@ class Trainer: finished_with_errors = False for callback in callbacks: - callback.start_training() + callback.start_training(logdir=self.logdir) try: iterator = self.run_iterator(
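
For reference, a minimal sketch of how the dictionary helpers relocated into ray.util.ml_utils.dict behave after this move. The sample dictionaries are invented for illustration; only the helper names and signatures come from the patch.

    from ray.util.ml_utils.dict import flatten_dict, merge_dicts, unflatten_dict

    # Invented nested result dict.
    nested = {"loss": 0.25, "metrics": {"acc": 0.9, "f1": 0.8}}

    # flatten_dict joins nested keys with the delimiter ...
    flat = flatten_dict(nested, delimiter="/")
    # -> {"loss": 0.25, "metrics/acc": 0.9, "metrics/f1": 0.8}

    # ... and unflatten_dict reverses it for purely dict-nested input.
    assert unflatten_dict(flat, delimiter="/") == nested

    # merge_dicts deep-merges, with the second dict taking precedence.
    defaults = {"optimizer": {"lr": 1e-3, "momentum": 0.9}}
    overrides = {"optimizer": {"lr": 1e-2}}
    assert merge_dicts(defaults, overrides) == {
        "optimizer": {"lr": 1e-2, "momentum": 0.9}
    }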
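
The reworked SGDCallback hooks now receive the run's logdir and a forward-compatible **info. A toy user-defined callback against that interface might look like the following; the print-based behavior is purely illustrative.

    from typing import Dict, List

    from ray.util.sgd.v2.callbacks import SGDCallback


    class PrintingCallback(SGDCallback):
        """Toy callback mirroring the hook signatures added in this patch."""

        def start_training(self, logdir: str, **info):
            # The Trainer passes its logdir, so the callback no longer
            # needs a directory up front.
            print(f"Logging run to {logdir}")

        def handle_result(self, results: List[Dict], **info):
            # One result dict per worker, indexed by worker rank.
            for rank, result in enumerate(results):
                print(f"worker {rank}: {result}")

        def finish_training(self, error: bool = False, **info):
            print("Training finished" + (" with an error" if error else ""))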
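
TBXLoggerCallback.handle_result flattens the logged worker's result and prefixes every key with ray/sgd before writing it out. A rough sketch of that tag construction, using an invented result dict shaped like the one in test_TBX and showing only the scalar path:

    import numpy as np

    from ray.util.ml_utils.dict import flatten_dict

    VALID_SUMMARY_TYPES = (int, float, np.float32, np.float64, np.int32,
                           np.int64)

    result = {"episode_reward_mean": 6, "hello": {"world": 1}}

    for attr, value in flatten_dict(result, delimiter="/").items():
        full_attr = "/".join(["ray", "sgd", attr])
        if isinstance(value, VALID_SUMMARY_TYPES) and not np.isnan(value):
            print(full_attr, value)
    # ray/sgd/episode_reward_mean 6
    # ray/sgd/hello/world 1
    # The "ray/sgd/" prefix is 8 characters, which is why the test strips
    # v.tag[8:] when collecting values from the event file.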
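
Putting the pieces together, a condensed usage sketch in the spirit of the updated train_linear.py example: neither logger callback is given a directory, so both pick up the Trainer's logdir when start_training(logdir=...) is invoked. The placeholder training function and the exact Trainer constructor arguments are assumptions, not part of the patch.

    import ray.util.sgd.v2 as sgd
    from ray.util.sgd.v2 import Trainer, TorchConfig
    from ray.util.sgd.v2.callbacks import JsonLoggerCallback, TBXLoggerCallback


    def train_func(config):
        # Placeholder loop; a real training function would update a model.
        for epoch in range(config["epochs"]):
            sgd.report(epoch=epoch, loss=1.0 / (epoch + 1))


    trainer = Trainer(TorchConfig(), num_workers=2)
    trainer.start()
    results = trainer.run(
        train_func,
        config={"epochs": 3},
        callbacks=[JsonLoggerCallback(), TBXLoggerCallback()])
    trainer.shutdown()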