Mirror of https://github.com/vale981/ray, synced 2025-03-06 02:21:39 -05:00
[SGDv2] Tensorboard Callback (#17824)
* [SGD] save checkpoints to disk
* fix test; add logs
* Extend SGDv2 callback API
* Move json file creation to JsonLoggerCallback
* TBXLoggerCallback
* Simplify, fix linear example
* rename log_dir to logdir for consistency with tune
* Add test
* Fix
* Break up logging classes
* Fix error
* Update type hint for results
* Refactor

Co-authored-by: Matthew Deng <matthew.j.deng@gmail.com>
parent 95b5ad12ba
commit 714193ce6f

11 changed files with 488 additions and 258 deletions
@@ -9,8 +9,7 @@ import inspect
import threading
import time
import uuid
from collections import defaultdict, deque
from collections.abc import Mapping, Sequence
from collections import defaultdict
from datetime import datetime
from threading import Thread
from typing import Optional

@@ -20,6 +19,9 @@ import ray
import psutil

from ray.util.ml_utils.json import SafeFallbackEncoder  # noqa
from ray.util.ml_utils.dict import merge_dicts, deep_update, flatten_dict, \
    unflatten_dict, unflatten_list_dict, \
    unflattened_lookup  # noqa

logger = logging.getLogger(__name__)

@@ -220,211 +222,6 @@ def is_nan_or_inf(value):
    return np.isnan(value) or np.isinf(value)


def merge_dicts(d1, d2):
    """
    Args:
        d1 (dict): Dict 1.
        d2 (dict): Dict 2.

    Returns:
        dict: A new dict that is d1 and d2 deep merged.
    """
    merged = copy.deepcopy(d1)
    deep_update(merged, d2, True, [])
    return merged


def deep_update(original,
                new_dict,
                new_keys_allowed=False,
                allow_new_subkey_list=None,
                override_all_if_type_changes=None):
    """Updates original dict with values from new_dict recursively.

    If new key is introduced in new_dict, then if new_keys_allowed is not
    True, an error will be thrown. Further, for sub-dicts, if the key is
    in the allow_new_subkey_list, then new subkeys can be introduced.

    Args:
        original (dict): Dictionary with default values.
        new_dict (dict): Dictionary with values to be updated
        new_keys_allowed (bool): Whether new keys are allowed.
        allow_new_subkey_list (Optional[List[str]]): List of keys that
            correspond to dict values where new subkeys can be introduced.
            This is only at the top level.
        override_all_if_type_changes(Optional[List[str]]): List of top level
            keys with value=dict, for which we always simply override the
            entire value (dict), iff the "type" key in that value dict changes.
    """
    allow_new_subkey_list = allow_new_subkey_list or []
    override_all_if_type_changes = override_all_if_type_changes or []

    for k, value in new_dict.items():
        if k not in original and not new_keys_allowed:
            raise Exception("Unknown config parameter `{}` ".format(k))

        # Both orginal value and new one are dicts.
        if isinstance(original.get(k), dict) and isinstance(value, dict):
            # Check old type vs old one. If different, override entire value.
            if k in override_all_if_type_changes and \
                    "type" in value and "type" in original[k] and \
                    value["type"] != original[k]["type"]:
                original[k] = value
            # Allowed key -> ok to add new subkeys.
            elif k in allow_new_subkey_list:
                deep_update(original[k], value, True)
            # Non-allowed key.
            else:
                deep_update(original[k], value, new_keys_allowed)
        # Original value not a dict OR new value not a dict:
        # Override entire value.
        else:
            original[k] = value
    return original


def flatten_dict(dt: Dict,
                 delimiter: str = "/",
                 prevent_delimiter: bool = False,
                 flatten_list: bool = False):
    """Flatten dict.

    Output and input are of the same dict type.
    Input dict remains the same after the operation.
    """

    def _raise_delimiter_exception():
        raise ValueError(
            f"Found delimiter `{delimiter}` in key when trying to flatten "
            f"array. Please avoid using the delimiter in your specification.")

    dt = copy.copy(dt)
    if prevent_delimiter and any(delimiter in key for key in dt):
        # Raise if delimiter is any of the keys
        _raise_delimiter_exception()

    while_check = (dict, list) if flatten_list else dict

    while any(isinstance(v, while_check) for v in dt.values()):
        remove = []
        add = {}
        for key, value in dt.items():
            if isinstance(value, dict):
                for subkey, v in value.items():
                    if prevent_delimiter and delimiter in subkey:
                        # Raise if delimiter is in any of the subkeys
                        _raise_delimiter_exception()

                    add[delimiter.join([key, str(subkey)])] = v
                remove.append(key)
            elif flatten_list and isinstance(value, list):
                for i, v in enumerate(value):
                    if prevent_delimiter and delimiter in subkey:
                        # Raise if delimiter is in any of the subkeys
                        _raise_delimiter_exception()

                    add[delimiter.join([key, str(i)])] = v
                remove.append(key)

        dt.update(add)
        for k in remove:
            del dt[k]
    return dt


def unflatten_dict(dt, delimiter="/"):
    """Unflatten dict. Does not support unflattening lists."""
    dict_type = type(dt)
    out = dict_type()
    for key, val in dt.items():
        path = key.split(delimiter)
        item = out
        for k in path[:-1]:
            item = item.setdefault(k, dict_type())
            if not isinstance(item, dict_type):
                raise TypeError(
                    f"Cannot unflatten dict due the key '{key}' "
                    f"having a parent key '{k}', which value is not "
                    f"of type {dict_type} (got {type(item)}). "
                    "Change the key names to resolve the conflict.")
        item[path[-1]] = val
    return out


def unflatten_list_dict(dt, delimiter="/"):
    """Unflatten nested dict and list.

    This function now has some limitations:
    (1) The keys of dt must be str.
    (2) If unflattened dt (the result) contains list, the index order must be
        ascending when accessing dt. Otherwise, this function will throw
        AssertionError.
    (3) The unflattened dt (the result) shouldn't contain dict with number
        keys.

    Be careful to use this function. If you want to improve this function,
    please also improve the unit test. See #14487 for more details.

    Args:
        dt (dict): Flattened dictionary that is originally nested by multiple
            list and dict.
        delimiter (str): Delimiter of keys.

    Example:
        >>> dt = {"aaa/0/bb": 12, "aaa/1/cc": 56, "aaa/1/dd": 92}
        >>> unflatten_list_dict(dt)
        {'aaa': [{'bb': 12}, {'cc': 56, 'dd': 92}]}
    """
    out_type = list if list(dt)[0].split(delimiter, 1)[0].isdigit() \
        else type(dt)
    out = out_type()
    for key, val in dt.items():
        path = key.split(delimiter)

        item = out
        for i, k in enumerate(path[:-1]):
            next_type = list if path[i + 1].isdigit() else dict
            if isinstance(item, dict):
                item = item.setdefault(k, next_type())
            elif isinstance(item, list):
                if int(k) >= len(item):
                    item.append(next_type())
                    assert int(k) == len(item) - 1
                item = item[int(k)]

        if isinstance(item, dict):
            item[path[-1]] = val
        elif isinstance(item, list):
            item.append(val)
            assert int(path[-1]) == len(item) - 1
    return out


def unflattened_lookup(flat_key, lookup, delimiter="/", **kwargs):
    """
    Unflatten `flat_key` and iteratively look up in `lookup`. E.g.
    `flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`.
    """
    if flat_key in lookup:
        return lookup[flat_key]
    keys = deque(flat_key.split(delimiter))
    base = lookup
    while keys:
        key = keys.popleft()
        try:
            if isinstance(base, Mapping):
                base = base[key]
            elif isinstance(base, Sequence):
                base = base[int(key)]
            else:
                raise KeyError()
        except KeyError as e:
            if "default" in kwargs:
                return kwargs["default"]
            raise e
    return base


def _to_pinnable(obj):
    """Converts obj to a form that can be pinned in object store memory.

python/ray/util/ml_utils/dict.py (new file, 216 lines)
@@ -0,0 +1,216 @@
from typing import Dict, List, Union, Optional, TypeVar
import copy
from collections import deque
from collections.abc import Mapping, Sequence

T = TypeVar("T")


def merge_dicts(d1: dict, d2: dict) -> dict:
    """
    Args:
        d1 (dict): Dict 1.
        d2 (dict): Dict 2.

    Returns:
        dict: A new dict that is d1 and d2 deep merged.
    """
    merged = copy.deepcopy(d1)
    deep_update(merged, d2, True, [])
    return merged


def deep_update(
        original: dict,
        new_dict: dict,
        new_keys_allowed: str = False,
        allow_new_subkey_list: Optional[List[str]] = None,
        override_all_if_type_changes: Optional[List[str]] = None) -> dict:
    """Updates original dict with values from new_dict recursively.

    If new key is introduced in new_dict, then if new_keys_allowed is not
    True, an error will be thrown. Further, for sub-dicts, if the key is
    in the allow_new_subkey_list, then new subkeys can be introduced.

    Args:
        original (dict): Dictionary with default values.
        new_dict (dict): Dictionary with values to be updated
        new_keys_allowed (bool): Whether new keys are allowed.
        allow_new_subkey_list (Optional[List[str]]): List of keys that
            correspond to dict values where new subkeys can be introduced.
            This is only at the top level.
        override_all_if_type_changes(Optional[List[str]]): List of top level
            keys with value=dict, for which we always simply override the
            entire value (dict), iff the "type" key in that value dict changes.
    """
    allow_new_subkey_list = allow_new_subkey_list or []
    override_all_if_type_changes = override_all_if_type_changes or []

    for k, value in new_dict.items():
        if k not in original and not new_keys_allowed:
            raise Exception("Unknown config parameter `{}` ".format(k))

        # Both orginal value and new one are dicts.
        if isinstance(original.get(k), dict) and isinstance(value, dict):
            # Check old type vs old one. If different, override entire value.
            if k in override_all_if_type_changes and \
                    "type" in value and "type" in original[k] and \
                    value["type"] != original[k]["type"]:
                original[k] = value
            # Allowed key -> ok to add new subkeys.
            elif k in allow_new_subkey_list:
                deep_update(original[k], value, True)
            # Non-allowed key.
            else:
                deep_update(original[k], value, new_keys_allowed)
        # Original value not a dict OR new value not a dict:
        # Override entire value.
        else:
            original[k] = value
    return original


def flatten_dict(dt: Dict,
                 delimiter: str = "/",
                 prevent_delimiter: bool = False,
                 flatten_list: bool = False):
    """Flatten dict.

    Output and input are of the same dict type.
    Input dict remains the same after the operation.
    """

    def _raise_delimiter_exception():
        raise ValueError(
            f"Found delimiter `{delimiter}` in key when trying to flatten "
            f"array. Please avoid using the delimiter in your specification.")

    dt = copy.copy(dt)
    if prevent_delimiter and any(delimiter in key for key in dt):
        # Raise if delimiter is any of the keys
        _raise_delimiter_exception()

    while_check = (dict, list) if flatten_list else dict

    while any(isinstance(v, while_check) for v in dt.values()):
        remove = []
        add = {}
        for key, value in dt.items():
            if isinstance(value, dict):
                for subkey, v in value.items():
                    if prevent_delimiter and delimiter in subkey:
                        # Raise if delimiter is in any of the subkeys
                        _raise_delimiter_exception()

                    add[delimiter.join([key, str(subkey)])] = v
                remove.append(key)
            elif flatten_list and isinstance(value, list):
                for i, v in enumerate(value):
                    if prevent_delimiter and delimiter in subkey:
                        # Raise if delimiter is in any of the subkeys
                        _raise_delimiter_exception()

                    add[delimiter.join([key, str(i)])] = v
                remove.append(key)

        dt.update(add)
        for k in remove:
            del dt[k]
    return dt


def unflatten_dict(dt: Dict[str, T], delimiter: str = "/") -> Dict[str, T]:
    """Unflatten dict. Does not support unflattening lists."""
    dict_type = type(dt)
    out = dict_type()
    for key, val in dt.items():
        path = key.split(delimiter)
        item = out
        for k in path[:-1]:
            item = item.setdefault(k, dict_type())
            if not isinstance(item, dict_type):
                raise TypeError(
                    f"Cannot unflatten dict due the key '{key}' "
                    f"having a parent key '{k}', which value is not "
                    f"of type {dict_type} (got {type(item)}). "
                    "Change the key names to resolve the conflict.")
        item[path[-1]] = val
    return out


def unflatten_list_dict(dt: Dict[str, T],
                        delimiter: str = "/") -> Dict[str, T]:
    """Unflatten nested dict and list.

    This function now has some limitations:
    (1) The keys of dt must be str.
    (2) If unflattened dt (the result) contains list, the index order must be
        ascending when accessing dt. Otherwise, this function will throw
        AssertionError.
    (3) The unflattened dt (the result) shouldn't contain dict with number
        keys.

    Be careful to use this function. If you want to improve this function,
    please also improve the unit test. See #14487 for more details.

    Args:
        dt (dict): Flattened dictionary that is originally nested by multiple
            list and dict.
        delimiter (str): Delimiter of keys.

    Example:
        >>> dt = {"aaa/0/bb": 12, "aaa/1/cc": 56, "aaa/1/dd": 92}
        >>> unflatten_list_dict(dt)
        {'aaa': [{'bb': 12}, {'cc': 56, 'dd': 92}]}
    """
    out_type = list if list(dt)[0].split(delimiter, 1)[0].isdigit() \
        else type(dt)
    out = out_type()
    for key, val in dt.items():
        path = key.split(delimiter)

        item = out
        for i, k in enumerate(path[:-1]):
            next_type = list if path[i + 1].isdigit() else dict
            if isinstance(item, dict):
                item = item.setdefault(k, next_type())
            elif isinstance(item, list):
                if int(k) >= len(item):
                    item.append(next_type())
                    assert int(k) == len(item) - 1
                item = item[int(k)]

        if isinstance(item, dict):
            item[path[-1]] = val
        elif isinstance(item, list):
            item.append(val)
            assert int(path[-1]) == len(item) - 1
    return out


def unflattened_lookup(flat_key: str,
                       lookup: Union[Mapping, Sequence],
                       delimiter: str = "/",
                       **kwargs) -> Union[Mapping, Sequence]:
    """
    Unflatten `flat_key` and iteratively look up in `lookup`. E.g.
    `flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`.
    """
    if flat_key in lookup:
        return lookup[flat_key]
    keys = deque(flat_key.split(delimiter))
    base = lookup
    while keys:
        key = keys.popleft()
        try:
            if isinstance(base, Mapping):
                base = base[key]
            elif isinstance(base, Sequence):
                base = base[int(key)]
            else:
                raise KeyError()
        except KeyError as e:
            if "default" in kwargs:
                return kwargs["default"]
            raise e
    return base

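The snippet below is not part of the commit; it is a small usage sketch of the helpers collected in python/ray/util/ml_utils/dict.py, assuming a Ray checkout that contains this new module.

from ray.util.ml_utils.dict import (flatten_dict, merge_dicts,
                                    unflatten_dict, unflattened_lookup)

# flatten_dict joins nested keys with the delimiter; unflatten_dict reverses
# it for dict-only structures; unflattened_lookup walks a nested container
# from a flat key; merge_dicts deep-merges two dicts.
nested = {"a": {"b": 1, "c": {"d": 2}}}
flat = flatten_dict(nested)                      # {'a/b': 1, 'a/c/d': 2}
assert unflatten_dict(flat) == nested
assert unflattened_lookup("a/c/d", nested) == 2
assert unflattened_lookup("a/x", nested, default=None) is None
assert merge_dicts({"x": 1, "z": {"k": 0}}, {"z": {"k": 5}}) == \
    {"x": 1, "z": {"k": 5}}
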
@@ -50,7 +50,8 @@ class BackendExecutor:
            generated.

    Attributes:
        logdir (Path): Path to the file directory where logs will be persisted.
        logdir (Path): Path to the file directory where logs will be
            persisted.
        latest_run_dir (Optional[Path]): Path to the file directory for the
            latest run. Configured through ``start_training``.
        latest_checkpoint_dir (Optional[Path]): Path to the file directory for

@@ -1,4 +1,5 @@
from ray.util.sgd.v2.callbacks.callback import SGDCallback
from ray.util.sgd.v2.callbacks.logging import JsonLoggerCallback
from ray.util.sgd.v2.callbacks.logging import (JsonLoggerCallback,
                                               TBXLoggerCallback)

__all__ = ["SGDCallback", "JsonLoggerCallback"]
__all__ = ["SGDCallback", "JsonLoggerCallback", "TBXLoggerCallback"]

@@ -1,21 +1,36 @@
import abc
from typing import List, Optional, Dict
from typing import List, Dict


class SGDCallback(metaclass=abc.ABCMeta):
    """Abstract SGD callback class."""

    def handle_result(self, results: Optional[List[Dict]]):
        """Called every time sgd.report() is called."""
    def handle_result(self, results: List[Dict], **info):
        """Called every time sgd.report() is called.

        Args:
            results (List[Dict]): List of results from the training
                function. Each value in the list corresponds to the output of
                the training function from each worker.
            **info: kwargs dict for forward compatibility.
        """
        pass

    def start_training(self):
        """Called once on training start."""
    def start_training(self, logdir: str, **info):
        """Called once on training start.

        Args:
            logdir (str): Path to the file directory where logs
                should be persisted.
            **info: kwargs dict for forward compatibility.
        """
        pass

    def finish_training(self, error: bool = False):
    def finish_training(self, error: bool = False, **info):
        """Called once after training is over.

        Args:
            error (bool): If True, there was an exception during training."""
            error (bool): If True, there was an exception during training.
            **info: kwargs dict for forward compatibility.
        """
        pass

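Not part of the diff above: a minimal sketch of a user-defined callback written against the updated SGDCallback interface. The class name PrintCallback is made up for illustration.

from typing import Dict, List

from ray.util.sgd.v2.callbacks import SGDCallback


class PrintCallback(SGDCallback):
    def start_training(self, logdir: str, **info):
        # Called once; `logdir` is supplied by the Trainer.
        print(f"logs go to {logdir}")

    def handle_result(self, results: List[Dict], **info):
        # One result dict per worker, in worker order.
        for rank, result in enumerate(results):
            print(f"worker {rank}: {result}")

    def finish_training(self, error: bool = False, **info):
        print("finished with errors" if error else "finished cleanly")
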
@@ -1,22 +1,52 @@
from typing import Iterable, List, Optional, Dict, Union
from typing import Iterable, List, Optional, Dict, Set, Tuple, Union
import abc

import warnings
import logging
import numpy as np
import json
from pathlib import Path

from ray.util.debug import log_once
from ray.util.ml_utils.dict import flatten_dict
from ray.util.ml_utils.json import SafeFallbackEncoder
from ray.util.sgd.v2.callbacks import SGDCallback
from ray.util.sgd.v2.constants import RESULT_FILE_JSON
from ray.util.sgd.v2.constants import (RESULT_FILE_JSON, TRAINING_ITERATION,
                                       TIME_TOTAL_S, TIMESTAMP, PID)

logger = logging.getLogger(__name__)


class SGDSingleFileLoggingCallback(SGDCallback, metaclass=abc.ABCMeta):
class SGDLogdirMixin:
    def start_training(self, logdir: str, **info):
        if self._logdir:
            logdir_path = Path(self._logdir)
        else:
            logdir_path = Path(logdir)

        if not logdir_path.is_dir():
            raise ValueError(f"logdir '{logdir}' must be a directory.")

        self._logdir_path = logdir_path

    @property
    def logdir(self) -> Optional[Path]:
        """Path to currently used logging directory."""
        if not hasattr(self, "_logdir_path"):
            return Path(self._logdir)
        return Path(self._logdir_path)


class SGDSingleFileLoggingCallback(
        SGDLogdirMixin, SGDCallback, metaclass=abc.ABCMeta):
    """Abstract SGD logging callback class.

    Args:
        logdir (str): Path to directory where the results file should be.
        filename (str|None): Filename in logdir to save results to.
        logdir (Optional[str]): Path to directory where the results file
            should be. If None, will be set by the Trainer.
        filename (Optional[str]): Filename in logdir to save results to.
        workers_to_log (int|List[int]|None): Worker indices to log.
            If None, will log all workers.
            If None, will log all workers. By default, will log the
            worker with index 0.
    """

    # Defining it like this ensures it will be overwritten

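A hedged sketch (not in the commit) of the logdir resolution that SGDLogdirMixin introduces: a logdir passed to the callback constructor wins, otherwise the directory handed over by the Trainer in start_training() is used. It assumes JsonLoggerCallback as defined later in this file is importable.

import tempfile

from ray.util.sgd.v2.callbacks import JsonLoggerCallback

trainer_dir = tempfile.mkdtemp()
explicit_dir = tempfile.mkdtemp()

auto = JsonLoggerCallback()                 # logdir filled in by the Trainer
auto.start_training(logdir=trainer_dir)
assert str(auto.logdir) == trainer_dir

pinned = JsonLoggerCallback(explicit_dir)   # explicit logdir takes precedence
pinned.start_training(logdir=trainer_dir)
assert str(pinned.logdir) == explicit_dir
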
@@ -24,18 +54,15 @@ class SGDSingleFileLoggingCallback(SGDCallback, metaclass=abc.ABCMeta):
    _default_filename: Union[str, Path]

    def __init__(self,
                 logdir: str,
                 logdir: Optional[str] = None,
                 filename: Optional[str] = None,
                 workers_to_log: Optional[Union[int, List[int]]] = 0) -> None:
        logdir_path = Path(logdir)
        self._logdir = logdir
        self._filename = filename
        self._workers_to_log = self._validate_workers_to_log(workers_to_log)
        self._log_path = None

        if not logdir_path.is_dir():
            raise ValueError(f"logdir '{logdir}' must be a directory.")

        if filename is None:
            filename = self._default_filename

        self._log_path = logdir_path.joinpath(Path(filename))
    def _validate_workers_to_log(self, workers_to_log) -> List[int]:
        if isinstance(workers_to_log, int):
            workers_to_log = [workers_to_log]

@@ -46,35 +73,58 @@ class SGDSingleFileLoggingCallback(SGDCallback, metaclass=abc.ABCMeta):
        if not all(isinstance(worker, int) for worker in workers_to_log):
            raise TypeError(
                "All elements of workers_to_log must be integers.")
        if len(workers_to_log) < 1:
            raise ValueError(
                "At least one worker must be specified in workers_to_log.")
        return workers_to_log

        self._workers_to_log = workers_to_log
    def _create_log_path(self, logdir_path: Path, filename: Path) -> Path:
        if not filename:
            raise ValueError("filename cannot be None or empty.")
        return logdir_path.joinpath(Path(filename))

    def start_training(self, logdir: str, **info):
        super().start_training(logdir, **info)

        if not self._filename:
            filename = self._default_filename
        else:
            filename = self._filename

        self._log_path = self._create_log_path(self.logdir, filename)

    @property
    def log_path(self) -> Path:
        """Path to the log file."""
        return self._log_path
    def log_path(self) -> Optional[Path]:
        """Path to the log file.

    def start_training(self):
        # Create a JSON file with an empty list
        # that will be latter appended to
        with open(self._log_path, "w") as f:
            json.dump([], f, cls=SafeFallbackEncoder)
        Will be None before `start_training` is called for the first time.
        """
        return self._log_path


class JsonLoggerCallback(SGDSingleFileLoggingCallback):
    """Logs SGD results in json format.

    Args:
        logdir (str): Path to directory where the results file should be.
        filename (str|None): Filename in logdir to save results to.
            Defaults to "results.json".
        logdir (Optional[str]): Path to directory where the results file
            should be. If None, will be set by the Trainer.
        filename (Optional[str]): Filename in logdir to save results to.
        workers_to_log (int|List[int]|None): Worker indices to log.
            If None, will log all workers.
            If None, will log all workers. By default, will log the
            worker with index 0.
    """

    _default_filename: Union[str, Path] = RESULT_FILE_JSON

    def handle_result(self, results: Optional[List[Dict]]):
    def start_training(self, logdir: str, **info):
        super().start_training(logdir, **info)

        # Create a JSON file with an empty list
        # that will be latter appended to
        with open(self._log_path, "w") as f:
            json.dump([], f, cls=SafeFallbackEncoder)

    def handle_result(self, results: List[Dict], **info):
        if self._workers_to_log is None or results is None:
            results_to_log = results
        else:

@@ -87,3 +137,105 @@ class JsonLoggerCallback(SGDSingleFileLoggingCallback):
            f.seek(0)
            json.dump(
                loaded_results + [results_to_log], f, cls=SafeFallbackEncoder)


class SGDSingleWorkerLoggingCallback(
        SGDLogdirMixin, SGDCallback, metaclass=abc.ABCMeta):
    """Abstract SGD logging callback class.

    Allows only for single-worker logging.

    Args:
        logdir (Optional[str]): Path to directory where the results file
            should be. If None, will be set by the Trainer.
        worker_to_log (int): Worker index to log. By default, will log the
            worker with index 0.
    """

    def __init__(self, logdir: Optional[str] = None,
                 worker_to_log: int = 0) -> None:
        self._logdir = logdir
        self._workers_to_log = self._validate_worker_to_log(worker_to_log)
        self._log_path = None

    def _validate_worker_to_log(self, worker_to_log) -> int:
        if isinstance(worker_to_log, Iterable):
            worker_to_log = list(worker_to_log)
            if len(worker_to_log) > 1:
                raise ValueError(
                    f"{self.__class__.__name__} only supports logging "
                    "from a single worker.")
            elif len(worker_to_log) < 1:
                raise ValueError(
                    "At least one worker must be specified in workers_to_log.")
            worker_to_log = worker_to_log[0]
        if not isinstance(worker_to_log, int):
            raise TypeError("workers_to_log must be an integer.")
        return worker_to_log


class TBXLoggerCallback(SGDSingleWorkerLoggingCallback):
    """Logs SGD results in TensorboardX format.

    Args:
        logdir (Optional[str]): Path to directory where the results file
            should be. If None, will be set by the Trainer.
        worker_to_log (int): Worker index to log. By default, will log the
            worker with index 0.
    """

    VALID_SUMMARY_TYPES: Tuple[type] = (int, float, np.float32, np.float64,
                                        np.int32, np.int64)
    IGNORE_KEYS: Set[str] = {PID, TIMESTAMP, TIME_TOTAL_S, TRAINING_ITERATION}

    def start_training(self, logdir: str, **info):
        super().start_training(logdir, **info)

        try:
            from tensorboardX import SummaryWriter
        except ImportError:
            if log_once("tbx-install"):
                warnings.warn(
                    "pip install 'tensorboardX' to see TensorBoard files.")
            raise

        self._file_writer = SummaryWriter(self.logdir, flush_secs=30)

    def handle_result(self, results: List[Dict], **info):
        result = results[self._workers_to_log]
        step = result[TRAINING_ITERATION]
        result = {k: v for k, v in result.items() if k not in self.IGNORE_KEYS}
        flat_result = flatten_dict(result, delimiter="/")
        path = ["ray", "sgd"]

        # same logic as in ray.tune.logger.TBXLogger
        for attr, value in flat_result.items():
            full_attr = "/".join(path + [attr])
            if (isinstance(value, self.VALID_SUMMARY_TYPES)
                    and not np.isnan(value)):
                self._file_writer.add_scalar(
                    full_attr, value, global_step=step)
            elif ((isinstance(value, list) and len(value) > 0)
                  or (isinstance(value, np.ndarray) and value.size > 0)):

                # Must be video
                if isinstance(value, np.ndarray) and value.ndim == 5:
                    self._file_writer.add_video(
                        full_attr, value, global_step=step, fps=20)
                    continue

                try:
                    self._file_writer.add_histogram(
                        full_attr, value, global_step=step)
                # In case TensorboardX still doesn't think it's a valid value
                # (e.g. `[[]]`), warn and move on.
                except (ValueError, TypeError):
                    if log_once("invalid_tbx_value"):
                        warnings.warn(
                            "You are trying to log an invalid value ({}={}) "
                            "via {}!".format(full_attr, value,
                                             type(self).__name__))
        self._file_writer.flush()

    def finish_training(self, error: bool = False, **info):
        self._file_writer.close()

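The snippet below is not from the commit; it is a rough sketch of how handle_result() above turns one worker's reported result into TensorBoard tags: bookkeeping keys are dropped, the remaining keys are flattened with "/" and prefixed with ray/sgd. The underscore-prefix check here stands in for the real IGNORE_KEYS set, and the sample result dict is made up.

from ray.util.ml_utils.dict import flatten_dict

result = {"_training_iteration": 2, "loss": 0.25, "config": {"lr": 0.01}}
step = result["_training_iteration"]
# Drop autofilled bookkeeping keys (approximated by the leading underscore).
filtered = {k: v for k, v in result.items() if not k.startswith("_")}
tags = {"/".join(["ray", "sgd", key]): value
        for key, value in flatten_dict(filtered, delimiter="/").items()}
print(step, tags)  # 2 {'ray/sgd/loss': 0.25, 'ray/sgd/config/lr': 0.01}
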
@@ -1,6 +1,6 @@
# Autofilled sgd.report() metrics. Keys should be consistent with Tune.
from pathlib import Path

# Autofilled sgd.report() metrics. Keys should be consistent with Tune.
TIMESTAMP = "_timestamp"
TIME_THIS_ITER_S = "_time_this_iter_s"
TRAINING_ITERATION = "_training_iteration"

@@ -1,11 +1,11 @@
import argparse

import numpy as np
import ray.util.sgd.v2 as sgd
import torch
import torch.nn as nn
import ray.util.sgd.v2 as sgd
from ray.util.sgd.v2 import Trainer, TorchConfig
from ray.util.sgd.v2.callbacks import JsonLoggerCallback
from ray.util.sgd.v2.callbacks import JsonLoggerCallback, TBXLoggerCallback
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler

@@ -92,7 +92,10 @@ def train_linear(num_workers=1):
    config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": 3}
    trainer.start()
    results = trainer.run(
        train_func, config, callbacks=[JsonLoggerCallback("./sgd_results")])
        train_func,
        config,
        callbacks=[JsonLoggerCallback(),
                   TBXLoggerCallback()])
    trainer.shutdown()

    print(results)

@@ -110,7 +113,7 @@ if __name__ == "__main__":
        "--num-workers",
        "-n",
        type=int,
        default=1,
        default=2,
        help="Sets number of workers for training.")
    parser.add_argument(
        "--smoke-test",

@@ -3,6 +3,8 @@ import os
import shutil
import tempfile
import json
import glob
from collections import defaultdict

import ray
import ray.util.sgd.v2 as sgd

@@ -10,10 +12,16 @@ from ray.util.sgd.v2 import Trainer
from ray.util.sgd.v2.constants import (
    TRAINING_ITERATION, DETAILED_AUTOFILLED_KEYS, BASIC_AUTOFILLED_KEYS,
    ENABLE_DETAILED_AUTOFILLED_METRICS_ENV)
from ray.util.sgd.v2.callbacks import JsonLoggerCallback
from ray.util.sgd.v2.callbacks import JsonLoggerCallback, TBXLoggerCallback
from ray.util.sgd.v2.backends.backend import BackendConfig, BackendInterface
from ray.util.sgd.v2.worker_group import WorkerGroup

try:
    from tensorflow.python.summary.summary_iterator \
        import summary_iterator
except ImportError:
    summary_iterator = None


@pytest.fixture
def ray_start_4_cpus():

@@ -78,15 +86,17 @@ def test_json(ray_start_4_cpus, make_temp_dir, workers_to_log, detailed,
        # if None, use default value
        callback = JsonLoggerCallback(
            make_temp_dir, workers_to_log=workers_to_log)
        assert str(
            callback.log_path.name) == JsonLoggerCallback._default_filename
    else:
        callback = JsonLoggerCallback(
            make_temp_dir, filename=filename, workers_to_log=workers_to_log)
        assert str(callback.log_path.name) == filename
    trainer = Trainer(config, num_workers=num_workers)
    trainer.start()
    trainer.run(train_func, callbacks=[callback])
    if filename is None:
        assert str(
            callback.log_path.name) == JsonLoggerCallback._default_filename
    else:
        assert str(callback.log_path.name) == filename

    with open(callback.log_path, "r") as f:
        log = json.load(f)

@@ -110,3 +120,39 @@ def test_json(ray_start_4_cpus, make_temp_dir, workers_to_log, detailed,
    assert all(
        all(not any(key in worker for key in DETAILED_AUTOFILLED_KEYS)
            for worker in element) for element in log)


def _validate_tbx_result(events_dir):
    events_file = list(glob.glob(f"{events_dir}/events*"))[0]
    results = defaultdict(list)
    for event in summary_iterator(events_file):
        for v in event.summary.value:
            assert v.tag.startswith("ray/sgd")
            results[v.tag[8:]].append(v.simple_value)

    assert len(results["episode_reward_mean"]) == 3
    assert [int(res) for res in results["episode_reward_mean"]] == [4, 5, 6]
    assert len(results["score"]) == 1
    assert len(results["hello/world"]) == 1


@pytest.mark.skipif(
    summary_iterator is None, reason="tensorboard is not installed")
def test_TBX(ray_start_4_cpus, make_temp_dir):
    config = TestConfig()

    temp_dir = make_temp_dir
    num_workers = 4

    def train_func():
        sgd.report(episode_reward_mean=4)
        sgd.report(episode_reward_mean=5)
        sgd.report(episode_reward_mean=6, score=[1, 2, 3], hello={"world": 1})
        return 1

    callback = TBXLoggerCallback(temp_dir)
    trainer = Trainer(config, num_workers=num_workers)
    trainer.start()
    trainer.run(train_func, callbacks=[callback])

    _validate_tbx_result(temp_dir)

@@ -7,7 +7,6 @@ import ray
import ray.util.sgd.v2 as sgd
import tensorflow as tf
import torch

from ray.util.sgd.v2 import Trainer
from ray.util.sgd.v2.backends.backend import BackendConfig, BackendInterface, \
    BackendExecutor

@@ -158,7 +158,7 @@ class Trainer:
        finished_with_errors = False

        for callback in callbacks:
            callback.start_training()
            callback.start_training(logdir=self.logdir)

        try:
            iterator = self.run_iterator(

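Not the actual Trainer internals, just a self-contained sketch of the calling convention the hunk above implies: the Trainer passes its logdir to every callback once at startup, fans each reported result list out to handle_result(), and always closes with finish_training().

def drive_callbacks(callbacks, logdir, result_batches):
    # result_batches: one List[Dict] (one dict per worker) per sgd.report().
    for callback in callbacks:
        callback.start_training(logdir=logdir)
    error = False
    try:
        for results in result_batches:
            for callback in callbacks:
                callback.handle_result(results)
    except Exception:
        error = True
        raise
    finally:
        for callback in callbacks:
            callback.finish_training(error=error)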