[SGDv2] Tensorboard Callback (#17824)

* [SGD] save checkpoints to disk

* fix test; add logs

* Extend SGDv2 callback API

* Move json file creation to JsonLoggerCallback

* TBXLoggerCallback

* Simplify, fix linear example

* rename log_dir to logdir for consistency with tune

* Add test

* Fix

* Break up logging classes

* Fix error

* Update type hint for results

* Refactor

Co-authored-by: Matthew Deng <matthew.j.deng@gmail.com>
Antoni Baum authored 2021-08-28 02:50:26 +02:00, committed via GitHub (parent 95b5ad12ba, commit 714193ce6f)
11 changed files with 488 additions and 258 deletions
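
Taken together, these diffs let an SGDv2 Trainer hand its log directory to logging callbacks and add a TensorBoard (tensorboardX) logger next to the existing JSON one. A rough end-to-end sketch of the resulting API, assuming Ray at this revision with torch installed (the training function body is illustrative only):

import ray
import ray.util.sgd.v2 as sgd
from ray.util.sgd.v2 import Trainer, TorchConfig
from ray.util.sgd.v2.callbacks import JsonLoggerCallback, TBXLoggerCallback

def train_func(config):
    # Each sgd.report() call fans out one result dict per worker to the callbacks.
    for epoch in range(config["epochs"]):
        sgd.report(loss=1.0 / (epoch + 1))

ray.init(num_cpus=4)
trainer = Trainer(TorchConfig(), num_workers=2)
trainer.start()
# Constructed without a logdir, both callbacks pick up the Trainer-provided one
# via start_training(logdir=...), as wired up in trainer.py at the bottom of this diff.
results = trainer.run(
    train_func,
    config={"epochs": 3},
    callbacks=[JsonLoggerCallback(), TBXLoggerCallback()])
trainer.shutdown()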


@@ -9,8 +9,7 @@ import inspect
import threading
import time
import uuid
from collections import defaultdict, deque
from collections.abc import Mapping, Sequence
from collections import defaultdict
from datetime import datetime
from threading import Thread
from typing import Optional
@@ -20,6 +19,9 @@ import ray
import psutil
from ray.util.ml_utils.json import SafeFallbackEncoder # noqa
from ray.util.ml_utils.dict import merge_dicts, deep_update, flatten_dict, \
unflatten_dict, unflatten_list_dict, \
unflattened_lookup # noqa
logger = logging.getLogger(__name__)
@@ -220,211 +222,6 @@ def is_nan_or_inf(value):
return np.isnan(value) or np.isinf(value)
def merge_dicts(d1, d2):
"""
Args:
d1 (dict): Dict 1.
d2 (dict): Dict 2.
Returns:
dict: A new dict that is d1 and d2 deep merged.
"""
merged = copy.deepcopy(d1)
deep_update(merged, d2, True, [])
return merged
def deep_update(original,
new_dict,
new_keys_allowed=False,
allow_new_subkey_list=None,
override_all_if_type_changes=None):
"""Updates original dict with values from new_dict recursively.
If new key is introduced in new_dict, then if new_keys_allowed is not
True, an error will be thrown. Further, for sub-dicts, if the key is
in the allow_new_subkey_list, then new subkeys can be introduced.
Args:
original (dict): Dictionary with default values.
new_dict (dict): Dictionary with values to be updated
new_keys_allowed (bool): Whether new keys are allowed.
allow_new_subkey_list (Optional[List[str]]): List of keys that
correspond to dict values where new subkeys can be introduced.
This is only at the top level.
override_all_if_type_changes(Optional[List[str]]): List of top level
keys with value=dict, for which we always simply override the
entire value (dict), iff the "type" key in that value dict changes.
"""
allow_new_subkey_list = allow_new_subkey_list or []
override_all_if_type_changes = override_all_if_type_changes or []
for k, value in new_dict.items():
if k not in original and not new_keys_allowed:
raise Exception("Unknown config parameter `{}` ".format(k))
# Both orginal value and new one are dicts.
if isinstance(original.get(k), dict) and isinstance(value, dict):
# Check old type vs old one. If different, override entire value.
if k in override_all_if_type_changes and \
"type" in value and "type" in original[k] and \
value["type"] != original[k]["type"]:
original[k] = value
# Allowed key -> ok to add new subkeys.
elif k in allow_new_subkey_list:
deep_update(original[k], value, True)
# Non-allowed key.
else:
deep_update(original[k], value, new_keys_allowed)
# Original value not a dict OR new value not a dict:
# Override entire value.
else:
original[k] = value
return original
def flatten_dict(dt: Dict,
delimiter: str = "/",
prevent_delimiter: bool = False,
flatten_list: bool = False):
"""Flatten dict.
Output and input are of the same dict type.
Input dict remains the same after the operation.
"""
def _raise_delimiter_exception():
raise ValueError(
f"Found delimiter `{delimiter}` in key when trying to flatten "
f"array. Please avoid using the delimiter in your specification.")
dt = copy.copy(dt)
if prevent_delimiter and any(delimiter in key for key in dt):
# Raise if delimiter is any of the keys
_raise_delimiter_exception()
while_check = (dict, list) if flatten_list else dict
while any(isinstance(v, while_check) for v in dt.values()):
remove = []
add = {}
for key, value in dt.items():
if isinstance(value, dict):
for subkey, v in value.items():
if prevent_delimiter and delimiter in subkey:
# Raise if delimiter is in any of the subkeys
_raise_delimiter_exception()
add[delimiter.join([key, str(subkey)])] = v
remove.append(key)
elif flatten_list and isinstance(value, list):
for i, v in enumerate(value):
if prevent_delimiter and delimiter in subkey:
# Raise if delimiter is in any of the subkeys
_raise_delimiter_exception()
add[delimiter.join([key, str(i)])] = v
remove.append(key)
dt.update(add)
for k in remove:
del dt[k]
return dt
def unflatten_dict(dt, delimiter="/"):
"""Unflatten dict. Does not support unflattening lists."""
dict_type = type(dt)
out = dict_type()
for key, val in dt.items():
path = key.split(delimiter)
item = out
for k in path[:-1]:
item = item.setdefault(k, dict_type())
if not isinstance(item, dict_type):
raise TypeError(
f"Cannot unflatten dict due the key '{key}' "
f"having a parent key '{k}', which value is not "
f"of type {dict_type} (got {type(item)}). "
"Change the key names to resolve the conflict.")
item[path[-1]] = val
return out
def unflatten_list_dict(dt, delimiter="/"):
"""Unflatten nested dict and list.
This function now has some limitations:
(1) The keys of dt must be str.
(2) If unflattened dt (the result) contains list, the index order must be
ascending when accessing dt. Otherwise, this function will throw
AssertionError.
(3) The unflattened dt (the result) shouldn't contain dict with number
keys.
Be careful to use this function. If you want to improve this function,
please also improve the unit test. See #14487 for more details.
Args:
dt (dict): Flattened dictionary that is originally nested by multiple
list and dict.
delimiter (str): Delimiter of keys.
Example:
>>> dt = {"aaa/0/bb": 12, "aaa/1/cc": 56, "aaa/1/dd": 92}
>>> unflatten_list_dict(dt)
{'aaa': [{'bb': 12}, {'cc': 56, 'dd': 92}]}
"""
out_type = list if list(dt)[0].split(delimiter, 1)[0].isdigit() \
else type(dt)
out = out_type()
for key, val in dt.items():
path = key.split(delimiter)
item = out
for i, k in enumerate(path[:-1]):
next_type = list if path[i + 1].isdigit() else dict
if isinstance(item, dict):
item = item.setdefault(k, next_type())
elif isinstance(item, list):
if int(k) >= len(item):
item.append(next_type())
assert int(k) == len(item) - 1
item = item[int(k)]
if isinstance(item, dict):
item[path[-1]] = val
elif isinstance(item, list):
item.append(val)
assert int(path[-1]) == len(item) - 1
return out
def unflattened_lookup(flat_key, lookup, delimiter="/", **kwargs):
"""
Unflatten `flat_key` and iteratively look up in `lookup`. E.g.
`flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`.
"""
if flat_key in lookup:
return lookup[flat_key]
keys = deque(flat_key.split(delimiter))
base = lookup
while keys:
key = keys.popleft()
try:
if isinstance(base, Mapping):
base = base[key]
elif isinstance(base, Sequence):
base = base[int(key)]
else:
raise KeyError()
except KeyError as e:
if "default" in kwargs:
return kwargs["default"]
raise e
return base
def _to_pinnable(obj):
"""Converts obj to a form that can be pinned in object store memory.


@@ -0,0 +1,216 @@
from typing import Dict, List, Union, Optional, TypeVar
import copy
from collections import deque
from collections.abc import Mapping, Sequence
T = TypeVar("T")
def merge_dicts(d1: dict, d2: dict) -> dict:
"""
Args:
d1 (dict): Dict 1.
d2 (dict): Dict 2.
Returns:
dict: A new dict that is d1 and d2 deep merged.
"""
merged = copy.deepcopy(d1)
deep_update(merged, d2, True, [])
return merged
def deep_update(
original: dict,
new_dict: dict,
new_keys_allowed: bool = False,
allow_new_subkey_list: Optional[List[str]] = None,
override_all_if_type_changes: Optional[List[str]] = None) -> dict:
"""Updates original dict with values from new_dict recursively.
If new key is introduced in new_dict, then if new_keys_allowed is not
True, an error will be thrown. Further, for sub-dicts, if the key is
in the allow_new_subkey_list, then new subkeys can be introduced.
Args:
original (dict): Dictionary with default values.
new_dict (dict): Dictionary with values to be updated
new_keys_allowed (bool): Whether new keys are allowed.
allow_new_subkey_list (Optional[List[str]]): List of keys that
correspond to dict values where new subkeys can be introduced.
This is only at the top level.
override_all_if_type_changes(Optional[List[str]]): List of top level
keys with value=dict, for which we always simply override the
entire value (dict), iff the "type" key in that value dict changes.
"""
allow_new_subkey_list = allow_new_subkey_list or []
override_all_if_type_changes = override_all_if_type_changes or []
for k, value in new_dict.items():
if k not in original and not new_keys_allowed:
raise Exception("Unknown config parameter `{}` ".format(k))
# Both the original value and the new one are dicts.
if isinstance(original.get(k), dict) and isinstance(value, dict):
# Check the old type vs. the new one. If different, override the entire value.
if k in override_all_if_type_changes and \
"type" in value and "type" in original[k] and \
value["type"] != original[k]["type"]:
original[k] = value
# Allowed key -> ok to add new subkeys.
elif k in allow_new_subkey_list:
deep_update(original[k], value, True)
# Non-allowed key.
else:
deep_update(original[k], value, new_keys_allowed)
# Original value not a dict OR new value not a dict:
# Override entire value.
else:
original[k] = value
return original
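As a quick illustration of the merge semantics just added (a standalone sketch with made-up configs):

from ray.util.ml_utils.dict import deep_update, merge_dicts

base = {"optimizer": {"type": "adam", "lr": 1e-3}, "epochs": 10}
override = {"optimizer": {"lr": 1e-2}, "epochs": 20}

# merge_dicts deep-copies d1, allows new keys, and merges nested dicts key by key.
assert merge_dicts(base, override) == {
    "optimizer": {"type": "adam", "lr": 1e-2}, "epochs": 20}

# With override_all_if_type_changes, a sub-dict whose "type" changes is replaced
# wholesale instead of being merged field by field.
updated = deep_update(
    dict(base), {"optimizer": {"type": "sgd", "lr": 0.1}},
    new_keys_allowed=False,
    override_all_if_type_changes=["optimizer"])
assert updated["optimizer"] == {"type": "sgd", "lr": 0.1}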
def flatten_dict(dt: Dict,
delimiter: str = "/",
prevent_delimiter: bool = False,
flatten_list: bool = False):
"""Flatten dict.
Output and input are of the same dict type.
Input dict remains the same after the operation.
"""
def _raise_delimiter_exception():
raise ValueError(
f"Found delimiter `{delimiter}` in key when trying to flatten "
f"array. Please avoid using the delimiter in your specification.")
dt = copy.copy(dt)
if prevent_delimiter and any(delimiter in key for key in dt):
# Raise if delimiter is any of the keys
_raise_delimiter_exception()
while_check = (dict, list) if flatten_list else dict
while any(isinstance(v, while_check) for v in dt.values()):
remove = []
add = {}
for key, value in dt.items():
if isinstance(value, dict):
for subkey, v in value.items():
if prevent_delimiter and delimiter in subkey:
# Raise if delimiter is in any of the subkeys
_raise_delimiter_exception()
add[delimiter.join([key, str(subkey)])] = v
remove.append(key)
elif flatten_list and isinstance(value, list):
for i, v in enumerate(value):
if prevent_delimiter and delimiter in subkey:
# Raise if delimiter is in any of the subkeys
_raise_delimiter_exception()
add[delimiter.join([key, str(i)])] = v
remove.append(key)
dt.update(add)
for k in remove:
del dt[k]
return dt
def unflatten_dict(dt: Dict[str, T], delimiter: str = "/") -> Dict[str, T]:
"""Unflatten dict. Does not support unflattening lists."""
dict_type = type(dt)
out = dict_type()
for key, val in dt.items():
path = key.split(delimiter)
item = out
for k in path[:-1]:
item = item.setdefault(k, dict_type())
if not isinstance(item, dict_type):
raise TypeError(
f"Cannot unflatten dict due the key '{key}' "
f"having a parent key '{k}', which value is not "
f"of type {dict_type} (got {type(item)}). "
"Change the key names to resolve the conflict.")
item[path[-1]] = val
return out
def unflatten_list_dict(dt: Dict[str, T],
delimiter: str = "/") -> Dict[str, T]:
"""Unflatten nested dict and list.
This function now has some limitations:
(1) The keys of dt must be str.
(2) If unflattened dt (the result) contains list, the index order must be
ascending when accessing dt. Otherwise, this function will throw
AssertionError.
(3) The unflattened dt (the result) shouldn't contain dict with number
keys.
Be careful to use this function. If you want to improve this function,
please also improve the unit test. See #14487 for more details.
Args:
dt (dict): Flattened dictionary that is originally nested by multiple
list and dict.
delimiter (str): Delimiter of keys.
Example:
>>> dt = {"aaa/0/bb": 12, "aaa/1/cc": 56, "aaa/1/dd": 92}
>>> unflatten_list_dict(dt)
{'aaa': [{'bb': 12}, {'cc': 56, 'dd': 92}]}
"""
out_type = list if list(dt)[0].split(delimiter, 1)[0].isdigit() \
else type(dt)
out = out_type()
for key, val in dt.items():
path = key.split(delimiter)
item = out
for i, k in enumerate(path[:-1]):
next_type = list if path[i + 1].isdigit() else dict
if isinstance(item, dict):
item = item.setdefault(k, next_type())
elif isinstance(item, list):
if int(k) >= len(item):
item.append(next_type())
assert int(k) == len(item) - 1
item = item[int(k)]
if isinstance(item, dict):
item[path[-1]] = val
elif isinstance(item, list):
item.append(val)
assert int(path[-1]) == len(item) - 1
return out
def unflattened_lookup(flat_key: str,
lookup: Union[Mapping, Sequence],
delimiter: str = "/",
**kwargs) -> Union[Mapping, Sequence]:
"""
Unflatten `flat_key` and iteratively look up in `lookup`. E.g.
`flat_key="a/0/b"` will try to return `lookup["a"][0]["b"]`.
"""
if flat_key in lookup:
return lookup[flat_key]
keys = deque(flat_key.split(delimiter))
base = lookup
while keys:
key = keys.popleft()
try:
if isinstance(base, Mapping):
base = base[key]
elif isinstance(base, Sequence):
base = base[int(key)]
else:
raise KeyError()
except KeyError as e:
if "default" in kwargs:
return kwargs["default"]
raise e
return base
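The flatten/unflatten helpers round-trip nested structures through delimiter-joined keys. A short behavioral sketch (values are made up):

from ray.util.ml_utils.dict import (flatten_dict, unflatten_dict,
                                    unflatten_list_dict, unflattened_lookup)

nested = {"a": {"b": 1, "c": [2, 3]}}

assert flatten_dict(nested) == {"a/b": 1, "a/c": [2, 3]}
assert flatten_dict(nested, flatten_list=True) == {
    "a/b": 1, "a/c/0": 2, "a/c/1": 3}

# unflatten_dict only rebuilds dicts; unflatten_list_dict also rebuilds lists
# from consecutive integer key components.
assert unflatten_dict({"a/b": 1, "a/c": 2}) == {"a": {"b": 1, "c": 2}}
assert unflatten_list_dict({"x/0/y": 1, "x/1/y": 2}) == {
    "x": [{"y": 1}, {"y": 2}]}

# unflattened_lookup walks a flat key through nested dicts and lists,
# returning a default (if given) when a dict key is missing.
assert unflattened_lookup("a/0/b", {"a": [{"b": 5}]}) == 5
assert unflattened_lookup("a/0/z", {"a": [{"b": 5}]}, default=None) is None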


@@ -50,7 +50,8 @@ class BackendExecutor:
generated.
Attributes:
logdir (Path): Path to the file directory where logs will be persisted.
logdir (Path): Path to the file directory where logs will be
persisted.
latest_run_dir (Optional[Path]): Path to the file directory for the
latest run. Configured through ``start_training``.
latest_checkpoint_dir (Optional[Path]): Path to the file directory for


@@ -1,4 +1,5 @@
from ray.util.sgd.v2.callbacks.callback import SGDCallback
from ray.util.sgd.v2.callbacks.logging import JsonLoggerCallback
from ray.util.sgd.v2.callbacks.logging import (JsonLoggerCallback,
TBXLoggerCallback)
__all__ = ["SGDCallback", "JsonLoggerCallback"]
__all__ = ["SGDCallback", "JsonLoggerCallback", "TBXLoggerCallback"]


@@ -1,21 +1,36 @@
import abc
from typing import List, Optional, Dict
from typing import List, Dict
class SGDCallback(metaclass=abc.ABCMeta):
"""Abstract SGD callback class."""
def handle_result(self, results: Optional[List[Dict]]):
"""Called every time sgd.report() is called."""
def handle_result(self, results: List[Dict], **info):
"""Called every time sgd.report() is called.
Args:
results (List[Dict]): List of results from the training
function. Each value in the list corresponds to the output of
the training function from each worker.
**info: kwargs dict for forward compatibility.
"""
pass
def start_training(self):
"""Called once on training start."""
def start_training(self, logdir: str, **info):
"""Called once on training start.
Args:
logdir (str): Path to the file directory where logs
should be persisted.
**info: kwargs dict for forward compatibility.
"""
pass
def finish_training(self, error: bool = False):
def finish_training(self, error: bool = False, **info):
"""Called once after training is over.
Args:
error (bool): If True, there was an exception during training."""
error (bool): If True, there was an exception during training.
**info: kwargs dict for forward compatibility.
"""
pass
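The **info kwargs keep these hooks forward compatible: a user-defined callback can accept and ignore arguments added later. A minimal custom callback built on this interface might look like the following sketch (the print formatting is arbitrary):

from typing import Dict, List

from ray.util.sgd.v2.callbacks import SGDCallback

class PrintingCallback(SGDCallback):
    """Prints the logdir on start and each worker's reported results."""

    def start_training(self, logdir: str, **info):
        print(f"Training started; logs will be persisted under {logdir}")

    def handle_result(self, results: List[Dict], **info):
        # One result dict per worker, in worker index order.
        for rank, result in enumerate(results):
            print(f"worker {rank}: {result}")

    def finish_training(self, error: bool = False, **info):
        print("Training finished" + (" with an error" if error else ""))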


@@ -1,22 +1,52 @@
from typing import Iterable, List, Optional, Dict, Union
from typing import Iterable, List, Optional, Dict, Set, Tuple, Union
import abc
import warnings
import logging
import numpy as np
import json
from pathlib import Path
from ray.util.debug import log_once
from ray.util.ml_utils.dict import flatten_dict
from ray.util.ml_utils.json import SafeFallbackEncoder
from ray.util.sgd.v2.callbacks import SGDCallback
from ray.util.sgd.v2.constants import RESULT_FILE_JSON
from ray.util.sgd.v2.constants import (RESULT_FILE_JSON, TRAINING_ITERATION,
TIME_TOTAL_S, TIMESTAMP, PID)
logger = logging.getLogger(__name__)
class SGDSingleFileLoggingCallback(SGDCallback, metaclass=abc.ABCMeta):
class SGDLogdirMixin:
def start_training(self, logdir: str, **info):
if self._logdir:
logdir_path = Path(self._logdir)
else:
logdir_path = Path(logdir)
if not logdir_path.is_dir():
raise ValueError(f"logdir '{logdir}' must be a directory.")
self._logdir_path = logdir_path
@property
def logdir(self) -> Optional[Path]:
"""Path to currently used logging directory."""
if not hasattr(self, "_logdir_path"):
return Path(self._logdir)
return Path(self._logdir_path)
class SGDSingleFileLoggingCallback(
SGDLogdirMixin, SGDCallback, metaclass=abc.ABCMeta):
"""Abstract SGD logging callback class.
Args:
logdir (str): Path to directory where the results file should be.
filename (str|None): Filename in logdir to save results to.
logdir (Optional[str]): Path to directory where the results file
should be. If None, will be set by the Trainer.
filename (Optional[str]): Filename in logdir to save results to.
workers_to_log (int|List[int]|None): Worker indices to log.
If None, will log all workers.
If None, will log all workers. By default, will log the
worker with index 0.
"""
# Defining it like this ensures it will be overwritten
@@ -24,18 +54,15 @@ class SGDSingleFileLoggingCallback(SGDCallback, metaclass=abc.ABCMeta):
_default_filename: Union[str, Path]
def __init__(self,
logdir: str,
logdir: Optional[str] = None,
filename: Optional[str] = None,
workers_to_log: Optional[Union[int, List[int]]] = 0) -> None:
logdir_path = Path(logdir)
self._logdir = logdir
self._filename = filename
self._workers_to_log = self._validate_workers_to_log(workers_to_log)
self._log_path = None
if not logdir_path.is_dir():
raise ValueError(f"logdir '{logdir}' must be a directory.")
if filename is None:
filename = self._default_filename
self._log_path = logdir_path.joinpath(Path(filename))
def _validate_workers_to_log(self, workers_to_log) -> List[int]:
if isinstance(workers_to_log, int):
workers_to_log = [workers_to_log]
@@ -46,35 +73,58 @@ class SGDSingleFileLoggingCallback(SGDCallback, metaclass=abc.ABCMeta):
if not all(isinstance(worker, int) for worker in workers_to_log):
raise TypeError(
"All elements of workers_to_log must be integers.")
if len(workers_to_log) < 1:
raise ValueError(
"At least one worker must be specified in workers_to_log.")
return workers_to_log
self._workers_to_log = workers_to_log
def _create_log_path(self, logdir_path: Path, filename: Path) -> Path:
if not filename:
raise ValueError("filename cannot be None or empty.")
return logdir_path.joinpath(Path(filename))
def start_training(self, logdir: str, **info):
super().start_training(logdir, **info)
if not self._filename:
filename = self._default_filename
else:
filename = self._filename
self._log_path = self._create_log_path(self.logdir, filename)
@property
def log_path(self) -> Path:
"""Path to the log file."""
return self._log_path
def log_path(self) -> Optional[Path]:
"""Path to the log file.
def start_training(self):
# Create a JSON file with an empty list
# that will be latter appended to
with open(self._log_path, "w") as f:
json.dump([], f, cls=SafeFallbackEncoder)
Will be None before `start_training` is called for the first time.
"""
return self._log_path
class JsonLoggerCallback(SGDSingleFileLoggingCallback):
"""Logs SGD results in json format.
Args:
logdir (str): Path to directory where the results file should be.
filename (str|None): Filename in logdir to save results to.
Defaults to "results.json".
logdir (Optional[str]): Path to directory where the results file
should be. If None, will be set by the Trainer.
filename (Optional[str]): Filename in logdir to save results to.
workers_to_log (int|List[int]|None): Worker indices to log.
If None, will log all workers.
If None, will log all workers. By default, will log the
worker with index 0.
"""
_default_filename: Union[str, Path] = RESULT_FILE_JSON
def handle_result(self, results: Optional[List[Dict]]):
def start_training(self, logdir: str, **info):
super().start_training(logdir, **info)
# Create a JSON file with an empty list
# that will later be appended to
with open(self._log_path, "w") as f:
json.dump([], f, cls=SafeFallbackEncoder)
def handle_result(self, results: List[Dict], **info):
if self._workers_to_log is None or results is None:
results_to_log = results
else:
@@ -87,3 +137,105 @@ class JsonLoggerCallback(SGDSingleFileLoggingCallback):
f.seek(0)
json.dump(
loaded_results + [results_to_log], f, cls=SafeFallbackEncoder)
class SGDSingleWorkerLoggingCallback(
SGDLogdirMixin, SGDCallback, metaclass=abc.ABCMeta):
"""Abstract SGD logging callback class.
Allows only for single-worker logging.
Args:
logdir (Optional[str]): Path to directory where the results file
should be. If None, will be set by the Trainer.
worker_to_log (int): Worker index to log. By default, will log the
worker with index 0.
"""
def __init__(self, logdir: Optional[str] = None,
worker_to_log: int = 0) -> None:
self._logdir = logdir
self._workers_to_log = self._validate_worker_to_log(worker_to_log)
self._log_path = None
def _validate_worker_to_log(self, worker_to_log) -> int:
if isinstance(worker_to_log, Iterable):
worker_to_log = list(worker_to_log)
if len(worker_to_log) > 1:
raise ValueError(
f"{self.__class__.__name__} only supports logging "
"from a single worker.")
elif len(worker_to_log) < 1:
raise ValueError(
"At least one worker must be specified in workers_to_log.")
worker_to_log = worker_to_log[0]
if not isinstance(worker_to_log, int):
raise TypeError("workers_to_log must be an integer.")
return worker_to_log
class TBXLoggerCallback(SGDSingleWorkerLoggingCallback):
"""Logs SGD results in TensorboardX format.
Args:
logdir (Optional[str]): Path to directory where the results file
should be. If None, will be set by the Trainer.
worker_to_log (int): Worker index to log. By default, will log the
worker with index 0.
"""
VALID_SUMMARY_TYPES: Tuple[type] = (int, float, np.float32, np.float64,
np.int32, np.int64)
IGNORE_KEYS: Set[str] = {PID, TIMESTAMP, TIME_TOTAL_S, TRAINING_ITERATION}
def start_training(self, logdir: str, **info):
super().start_training(logdir, **info)
try:
from tensorboardX import SummaryWriter
except ImportError:
if log_once("tbx-install"):
warnings.warn(
"pip install 'tensorboardX' to see TensorBoard files.")
raise
self._file_writer = SummaryWriter(self.logdir, flush_secs=30)
def handle_result(self, results: List[Dict], **info):
result = results[self._workers_to_log]
step = result[TRAINING_ITERATION]
result = {k: v for k, v in result.items() if k not in self.IGNORE_KEYS}
flat_result = flatten_dict(result, delimiter="/")
path = ["ray", "sgd"]
# same logic as in ray.tune.logger.TBXLogger
for attr, value in flat_result.items():
full_attr = "/".join(path + [attr])
if (isinstance(value, self.VALID_SUMMARY_TYPES)
and not np.isnan(value)):
self._file_writer.add_scalar(
full_attr, value, global_step=step)
elif ((isinstance(value, list) and len(value) > 0)
or (isinstance(value, np.ndarray) and value.size > 0)):
# Must be video
if isinstance(value, np.ndarray) and value.ndim == 5:
self._file_writer.add_video(
full_attr, value, global_step=step, fps=20)
continue
try:
self._file_writer.add_histogram(
full_attr, value, global_step=step)
# In case TensorboardX still doesn't think it's a valid value
# (e.g. `[[]]`), warn and move on.
except (ValueError, TypeError):
if log_once("invalid_tbx_value"):
warnings.warn(
"You are trying to log an invalid value ({}={}) "
"via {}!".format(full_attr, value,
type(self).__name__))
self._file_writer.flush()
def finish_training(self, error: bool = False, **info):
self._file_writer.close()
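Driving the TBX callback by hand shows what lands in TensorBoard: every reported metric becomes a scalar tagged ray/sgd/<flattened key>, TRAINING_ITERATION supplies the global step, and the autofilled bookkeeping keys in IGNORE_KEYS are dropped. A standalone sketch (assumes tensorboardX is installed and that the /tmp/sgd_tbx directory already exists):

from ray.util.sgd.v2.callbacks import TBXLoggerCallback
from ray.util.sgd.v2.constants import TRAINING_ITERATION

callback = TBXLoggerCallback(logdir="/tmp/sgd_tbx")
callback.start_training(logdir="/tmp/sgd_tbx")  # normally invoked by the Trainer

# One result dict per worker; only the worker at index 0 is logged by default.
callback.handle_result([
    {TRAINING_ITERATION: 1, "loss": 0.9, "scores": {"val": 0.4}},
    {TRAINING_ITERATION: 1, "loss": 1.1, "scores": {"val": 0.3}},
])
callback.finish_training()
# Writes scalars "ray/sgd/loss" and "ray/sgd/scores/val" at step 1.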


@@ -1,6 +1,6 @@
# Autofilled sgd.report() metrics. Keys should be consistent with Tune.
from pathlib import Path
# Autofilled sgd.report() metrics. Keys should be consistent with Tune.
TIMESTAMP = "_timestamp"
TIME_THIS_ITER_S = "_time_this_iter_s"
TRAINING_ITERATION = "_training_iteration"


@@ -1,11 +1,11 @@
import argparse
import numpy as np
import ray.util.sgd.v2 as sgd
import torch
import torch.nn as nn
import ray.util.sgd.v2 as sgd
from ray.util.sgd.v2 import Trainer, TorchConfig
from ray.util.sgd.v2.callbacks import JsonLoggerCallback
from ray.util.sgd.v2.callbacks import JsonLoggerCallback, TBXLoggerCallback
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DistributedSampler
@@ -92,7 +92,10 @@ def train_linear(num_workers=1):
config = {"lr": 1e-2, "hidden_size": 1, "batch_size": 4, "epochs": 3}
trainer.start()
results = trainer.run(
train_func, config, callbacks=[JsonLoggerCallback("./sgd_results")])
train_func,
config,
callbacks=[JsonLoggerCallback(),
TBXLoggerCallback()])
trainer.shutdown()
print(results)
@@ -110,7 +113,7 @@ if __name__ == "__main__":
"--num-workers",
"-n",
type=int,
default=1,
default=2,
help="Sets number of workers for training.")
parser.add_argument(
"--smoke-test",


@@ -3,6 +3,8 @@ import os
import shutil
import tempfile
import json
import glob
from collections import defaultdict
import ray
import ray.util.sgd.v2 as sgd
@@ -10,10 +12,16 @@ from ray.util.sgd.v2 import Trainer
from ray.util.sgd.v2.constants import (
TRAINING_ITERATION, DETAILED_AUTOFILLED_KEYS, BASIC_AUTOFILLED_KEYS,
ENABLE_DETAILED_AUTOFILLED_METRICS_ENV)
from ray.util.sgd.v2.callbacks import JsonLoggerCallback
from ray.util.sgd.v2.callbacks import JsonLoggerCallback, TBXLoggerCallback
from ray.util.sgd.v2.backends.backend import BackendConfig, BackendInterface
from ray.util.sgd.v2.worker_group import WorkerGroup
try:
from tensorflow.python.summary.summary_iterator \
import summary_iterator
except ImportError:
summary_iterator = None
@pytest.fixture
def ray_start_4_cpus():
@@ -78,15 +86,17 @@ def test_json(ray_start_4_cpus, make_temp_dir, workers_to_log, detailed,
# if None, use default value
callback = JsonLoggerCallback(
make_temp_dir, workers_to_log=workers_to_log)
assert str(
callback.log_path.name) == JsonLoggerCallback._default_filename
else:
callback = JsonLoggerCallback(
make_temp_dir, filename=filename, workers_to_log=workers_to_log)
assert str(callback.log_path.name) == filename
trainer = Trainer(config, num_workers=num_workers)
trainer.start()
trainer.run(train_func, callbacks=[callback])
if filename is None:
assert str(
callback.log_path.name) == JsonLoggerCallback._default_filename
else:
assert str(callback.log_path.name) == filename
with open(callback.log_path, "r") as f:
log = json.load(f)
@@ -110,3 +120,39 @@ def test_json(ray_start_4_cpus, make_temp_dir, workers_to_log, detailed,
assert all(
all(not any(key in worker for key in DETAILED_AUTOFILLED_KEYS)
for worker in element) for element in log)
def _validate_tbx_result(events_dir):
events_file = list(glob.glob(f"{events_dir}/events*"))[0]
results = defaultdict(list)
for event in summary_iterator(events_file):
for v in event.summary.value:
assert v.tag.startswith("ray/sgd")
results[v.tag[8:]].append(v.simple_value)
assert len(results["episode_reward_mean"]) == 3
assert [int(res) for res in results["episode_reward_mean"]] == [4, 5, 6]
assert len(results["score"]) == 1
assert len(results["hello/world"]) == 1
@pytest.mark.skipif(
summary_iterator is None, reason="tensorboard is not installed")
def test_TBX(ray_start_4_cpus, make_temp_dir):
config = TestConfig()
temp_dir = make_temp_dir
num_workers = 4
def train_func():
sgd.report(episode_reward_mean=4)
sgd.report(episode_reward_mean=5)
sgd.report(episode_reward_mean=6, score=[1, 2, 3], hello={"world": 1})
return 1
callback = TBXLoggerCallback(temp_dir)
trainer = Trainer(config, num_workers=num_workers)
trainer.start()
trainer.run(train_func, callbacks=[callback])
_validate_tbx_result(temp_dir)


@@ -7,7 +7,6 @@ import ray
import ray.util.sgd.v2 as sgd
import tensorflow as tf
import torch
from ray.util.sgd.v2 import Trainer
from ray.util.sgd.v2.backends.backend import BackendConfig, BackendInterface, \
BackendExecutor


@@ -158,7 +158,7 @@ class Trainer:
finished_with_errors = False
for callback in callbacks:
callback.start_training()
callback.start_training(logdir=self.logdir)
try:
iterator = self.run_iterator(