[AIR] Add rich notebook repr for DataParallelTrainer (#26335)

This commit is contained in:
Peyton Murray 2022-08-16 08:51:14 -07:00 committed by GitHub
parent bceef503b2
commit 4d19c0222b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 502 additions and 4 deletions

View file

@ -14,6 +14,7 @@ from typing import (
from ray.air.constants import WILDCARD_KEY
from ray.util.annotations import PublicAPI
from ray.widgets import Template, make_table_html_repr
# Move here later when ml_utils is deprecated. Doing it now causes a circular import.
@ -135,6 +136,9 @@ class ScalingConfig:
def __repr__(self):
return _repr_dataclass(self)
def _repr_html_(self) -> str:
return make_table_html_repr(obj=self, title=type(self).__name__)
def __eq__(self, o: "ScalingConfig") -> bool:
if not isinstance(o, type(self)):
return False
@ -323,6 +327,11 @@ class DatasetConfig:
def __repr__(self):
return _repr_dataclass(self)
def _repr_html_(self, title=None) -> str:
if title is None:
title = type(self).__name__
return make_table_html_repr(obj=self, title=title)
def fill_defaults(self) -> "DatasetConfig":
"""Return a copy of this config with all default values filled in."""
return DatasetConfig(
@ -460,6 +469,28 @@ class FailureConfig:
def __repr__(self):
return _repr_dataclass(self)
def _repr_html_(self):
try:
from tabulate import tabulate
except ImportError:
return (
"Tabulate isn't installed. Run "
"`pip install tabulate` for rich notebook output."
)
return Template("scrollableTable.html.j2").render(
table=tabulate(
{
"Setting": ["Max failures", "Fail fast"],
"Value": [self.max_failures, self.fail_fast],
},
tablefmt="html",
showindex=False,
headers="keys",
),
max_height="none",
)
@dataclass
@PublicAPI(stability="beta")
@ -527,6 +558,55 @@ class CheckpointConfig:
def __repr__(self):
return _repr_dataclass(self)
def _repr_html_(self) -> str:
try:
from tabulate import tabulate
except ImportError:
return (
"Tabulate isn't installed. Run "
"`pip install tabulate` for rich notebook output."
)
if self.num_to_keep is None:
num_to_keep_repr = "All"
else:
num_to_keep_repr = self.num_to_keep
if self.checkpoint_score_attribute is None:
checkpoint_score_attribute_repr = "Most recent"
else:
checkpoint_score_attribute_repr = self.checkpoint_score_attribute
if self.checkpoint_at_end is None:
checkpoint_at_end_repr = ""
else:
checkpoint_at_end_repr = self.checkpoint_at_end
return Template("scrollableTable.html.j2").render(
table=tabulate(
{
"Setting": [
"Number of checkpoints to keep",
"Checkpoint score attribute",
"Checkpoint score order",
"Checkpoint frequency",
"Checkpoint at end",
],
"Value": [
num_to_keep_repr,
checkpoint_score_attribute_repr,
self.checkpoint_score_order,
self.checkpoint_frequency,
checkpoint_at_end_repr,
],
},
tablefmt="html",
showindex=False,
headers="keys",
),
max_height="none",
)
@property
def _tune_legacy_checkpoint_score_attr(self) -> Optional[str]:
"""Same as ``checkpoint_score_attr`` in ``tune.run``.
@ -618,3 +698,59 @@ class RunConfig:
"checkpoint_config": CheckpointConfig(),
},
)
def _repr_html_(self) -> str:
try:
from tabulate import tabulate
except ImportError:
return (
"Tabulate isn't installed. Run "
"`pip install tabulate` for rich notebook output."
)
reprs = []
if self.failure_config is not None:
reprs.append(
Template("title_data_mini.html.j2").render(
title="Failure Config", data=self.failure_config._repr_html_()
)
)
if self.sync_config is not None:
reprs.append(
Template("title_data_mini.html.j2").render(
title="Sync Config", data=self.sync_config._repr_html_()
)
)
if self.checkpoint_config is not None:
reprs.append(
Template("title_data_mini.html.j2").render(
title="Checkpoint Config", data=self.checkpoint_config._repr_html_()
)
)
# Create a divider between each displayed repr
subconfigs = [Template("divider.html.j2").render()] * (2 * len(reprs) - 1)
subconfigs[::2] = reprs
settings = Template("scrollableTable.html.j2").render(
table=tabulate(
{
"Name": self.name,
"Local results directory": self.local_dir,
"Verbosity": self.verbose,
"Log to file": self.log_to_file,
}.items(),
tablefmt="html",
headers=["Setting", "Value"],
showindex=False,
),
max_height="300px",
)
return Template("title_data.html.j2").render(
title="RunConfig",
data=Template("run_config.html.j2").render(
subconfigs=subconfigs,
settings=settings,
),
)

View file

@ -3,6 +3,7 @@ import itertools
import logging
import os
import time
import html
from typing import (
TYPE_CHECKING,
Any,
@ -92,6 +93,7 @@ from ray.data.random_access_dataset import RandomAccessDataset
from ray.data.row import TableRow
from ray.types import ObjectRef
from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.widgets import Template
if TYPE_CHECKING:
import dask
@ -3599,6 +3601,83 @@ class Dataset(Generic[T]):
else:
return result
def _ipython_display_(self):
try:
from ipywidgets import HTML, VBox, Layout
except ImportError:
logger.warn(
"'ipywidgets' isn't installed. Run `pip install ipywidgets` to "
"enable notebook widgets."
)
return None
from IPython.display import display
title = HTML(f"<h2>{self.__class__.__name__}</h2>")
display(VBox([title, self._tab_repr_()], layout=Layout(width="100%")))
def _tab_repr_(self):
try:
from tabulate import tabulate
from ipywidgets import Tab, HTML
except ImportError:
logger.info(
"For rich Dataset reprs in notebooks, run "
"`pip install tabulate ipywidgets`."
)
return ""
metadata = {
"num_blocks": self._plan.initial_num_blocks(),
"num_rows": self._meta_count(),
}
schema = self.schema()
if schema is None:
schema_repr = Template("rendered_html_common.html.j2").render(
content="<h5>Unknown schema</h5>"
)
elif isinstance(schema, type):
schema_repr = Template("rendered_html_common.html.j2").render(
content=f"<h5>Data type: <code>{html.escape(str(schema))}</code></h5>"
)
else:
schema_data = {}
for sname, stype in zip(schema.names, schema.types):
schema_data[sname] = getattr(stype, "__name__", str(stype))
schema_repr = Template("scrollableTable.html.j2").render(
table=tabulate(
tabular_data=schema_data.items(),
tablefmt="html",
showindex=False,
headers=["Name", "Type"],
),
max_height="300px",
)
tab = Tab()
children = []
tab.set_title(0, "Metadata")
children.append(
Template("scrollableTable.html.j2").render(
table=tabulate(
tabular_data=metadata.items(),
tablefmt="html",
showindex=False,
headers=["Field", "Value"],
),
max_height="300px",
)
)
tab.set_title(1, "Schema")
children.append(schema_repr)
tab.children = [HTML(child) for child in children]
return tab
def __repr__(self) -> str:
schema = self.schema()
if schema is None:

View file

@ -4,6 +4,7 @@ from typing import TypeVar, Dict
from ray.train._internal.utils import Singleton
from ray.train._internal.worker_group import WorkerGroup
from ray.util.annotations import DeveloperAPI
from ray.widgets import make_table_html_repr
EncodedData = TypeVar("EncodedData")
@ -18,6 +19,9 @@ class BackendConfig:
def backend_cls(self):
return Backend
def _repr_html_(self) -> str:
return make_table_html_repr(obj=self, title=type(self).__name__)
@DeveloperAPI
class Backend(metaclass=Singleton):

View file

@ -3,6 +3,7 @@ import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
from tabulate import tabulate
import ray
from ray import tune
@ -19,6 +20,7 @@ from ray.train._internal.utils import construct_train_func
from ray.train.constants import TRAIN_DATASET_KEY, WILDCARD_KEY
from ray.train.trainer import BaseTrainer, GenDataset
from ray.util.annotations import DeveloperAPI
from ray.widgets import Template
if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor
@ -372,6 +374,115 @@ class DataParallelTrainer(BaseTrainer):
"""
return self._dataset_config.copy()
def _ipython_display_(self):
try:
from ipywidgets import HTML, VBox, Tab, Layout
except ImportError:
logger.warn(
"'ipywidgets' isn't installed. Run `pip install ipywidgets` to "
"enable notebook widgets."
)
return None
from IPython.display import display
title = HTML(f"<h2>{self.__class__.__name__}</h2>")
tab = Tab()
children = []
tab.set_title(0, "Datasets")
children.append(self._datasets_repr_() if self.datasets else None)
tab.set_title(1, "Dataset Config")
children.append(
HTML(self._dataset_config_repr_html_()) if self._dataset_config else None
)
tab.set_title(2, "Train Loop Config")
children.append(
HTML(self._train_loop_config_repr_html_())
if self._train_loop_config
else None
)
tab.set_title(3, "Scaling Config")
children.append(
HTML(self.scaling_config._repr_html_()) if self.scaling_config else None
)
tab.set_title(4, "Run Config")
children.append(
HTML(self.run_config._repr_html_()) if self.run_config else None
)
tab.set_title(5, "Backend Config")
children.append(
HTML(self._backend_config._repr_html_()) if self._backend_config else None
)
tab.children = children
display(VBox([title, tab], layout=Layout(width="100%")))
def _train_loop_config_repr_html_(self) -> str:
if self._train_loop_config:
table_data = {}
for k, v in self._train_loop_config.items():
if isinstance(v, str) or str(v).isnumeric():
table_data[k] = v
elif hasattr(v, "_repr_html_"):
table_data[k] = v._repr_html_()
else:
table_data[k] = str(v)
return Template("title_data.html.j2").render(
title="Train Loop Config",
data=Template("scrollableTable.html.j2").render(
table=tabulate(
table_data.items(),
headers=["Setting", "Value"],
showindex=False,
tablefmt="unsafehtml",
),
max_height="none",
),
)
else:
return ""
def _dataset_config_repr_html_(self) -> str:
content = []
if self._dataset_config:
for name, config in self._dataset_config.items():
content.append(
config._repr_html_(title=f"DatasetConfig - <code>{name}</code>")
)
return Template("rendered_html_common.html.j2").render(content=content)
def _datasets_repr_(self) -> str:
try:
from ipywidgets import HTML, VBox, Layout
except ImportError:
logger.warn(
"'ipywidgets' isn't installed. Run `pip install ipywidgets` to "
"enable notebook widgets."
)
return None
content = []
if self.datasets:
for name, config in self.datasets.items():
content.append(
HTML(
Template("title_data.html.j2").render(
title=f"Dataset - <code>{name}</code>", data=None
)
)
)
content.append(config._tab_repr_())
return VBox(content, layout=Layout(width="100%"))
def _load_checkpoint(
checkpoint: Checkpoint, trainer_name: str

View file

@ -30,6 +30,7 @@ from ray.tune.callback import Callback
from ray.tune.result import NODE_IP
from ray.tune.utils.file_transfer import sync_dir_between_nodes
from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.widgets import Template
if TYPE_CHECKING:
from ray.tune.experiment import Trial
@ -93,6 +94,41 @@ class SyncConfig:
sync_on_checkpoint: bool = True
sync_period: int = DEFAULT_SYNC_PERIOD
def _repr_html_(self) -> str:
"""Generate an HTML representation of the SyncConfig.
Note that self.syncer is omitted here; seems to have some overlap
with existing configuration settings here in the SyncConfig class.
"""
try:
from tabulate import tabulate
except ImportError:
return (
"Tabulate isn't installed. Run "
"`pip install tabulate` for rich notebook output."
)
return Template("scrollableTable.html.j2").render(
table=tabulate(
{
"Setting": [
"Upload directory",
"Sync on checkpoint",
"Sync period",
],
"Value": [
self.upload_dir,
self.sync_on_checkpoint,
self.sync_period,
],
},
tablefmt="html",
showindex=False,
headers="keys",
),
max_height="none",
)
class _BackgroundProcess:
def __init__(self, fn: Callable):
@ -308,6 +344,9 @@ class Syncer(abc.ABC):
def close(self):
pass
def _repr_html_(self) -> str:
return
class _BackgroundSyncer(Syncer):
"""Syncer using a background process for asynchronous file transfer."""

View file

@ -1,3 +1,4 @@
from ray.widgets.render import Template
from ray.widgets.util import make_table_html_repr
__all__ = ["Template"]
__all__ = ["Template", "make_table_html_repr"]

View file

@ -24,6 +24,8 @@ class Template:
"""
rendered = self.template
for key, value in kwargs.items():
if isinstance(value, List):
value = "".join(value)
rendered = rendered.replace("{{ " + key + " }}", value if value else "")
return rendered

View file

@ -0,0 +1,9 @@
<div class="vDivider"></div>
<style>
.vDivider {
border-left-width: var(--jp-border-width);
border-left-color: var(--jp-border-color0);
border-left-style: solid;
margin: 0.5em 1em 0.5em 1em;
}
</style>

View file

@ -0,0 +1,3 @@
<div class='jp-RenderedHTMLCommon'>
{{ content }}
</div>

View file

@ -0,0 +1,18 @@
<div class="runConfig">
<div class="generalSettings">
{{ settings }}
</div>
<div class="sideBySide">
{{ subconfigs }}
</div>
</div>
<style>
.sideBySide {
display: flex;
flex-direction: row;
gap: 1em;
}
.generalSettings {
border-bottom: var(--jp-border-width) solid var(--jp-border-color0);
}
</style>

View file

@ -0,0 +1,20 @@
<div class="scrollableTable jp-RenderedHTMLCommon">
{{ table }}
</div>
<style>
.scrollableTable {
overflow-y: auto;
max-height: {{ max_height }};
}
.scrollableTable table {
width: 100%;
}
.scrollableTable table :is(th,td) {
text-align: left !important;
}
.scrollableTable th {
background: var(--jp-layout-color1);
position: sticky;
top: 0;
}
</style>

View file

@ -0,0 +1,11 @@
<div class='titleData jp-RenderedHTMLCommon'>
<h3>{{ title }}</h3>
{{ data }}
</div>
<style>
.titleData h3 {
border-bottom-width: var(--jp-border-width);
border-bottom-color: var(--jp-border-color0);
border-bottom-style: solid;
}
</style>

View file

@ -0,0 +1,4 @@
<div class='miniTitleData jp-RenderedHTMLCommon'>
<h4><b>{{ title }}</b></h4>
{{ data }}
</div>

View file

@ -0,0 +1,61 @@
from typing import Any, Optional
from ray.util.annotations import DeveloperAPI
from ray.widgets import Template
@DeveloperAPI
def make_table_html_repr(
obj: Any, title: Optional[str] = None, max_height: str = "none"
) -> str:
"""Generate a generic html repr using a table.
Args:
obj: Object for which a repr is to be generated
title: If present, a title for the section is included
max_height: Maximum height of the table; valid values
are given by the max-height CSS property
Returns:
HTML representation of the object
"""
try:
from tabulate import tabulate
except ImportError:
return (
"Tabulate isn't installed. Run "
"`pip install tabulate` for rich notebook output."
)
data = {}
for k, v in vars(obj).items():
if isinstance(v, (str, bool, int, float)):
data[k] = str(v)
elif isinstance(v, dict) or hasattr(v, "__dict__"):
data[k] = Template("scrollableTable.html.j2").render(
table=tabulate(
v.items() if isinstance(v, dict) else vars(v).items(),
tablefmt="html",
showindex=False,
headers=["Setting", "Value"],
),
max_height="none",
)
table = Template("scrollableTable.html.j2").render(
table=tabulate(
data.items(),
tablefmt="unsafehtml",
showindex=False,
headers=["Setting", "Value"],
),
max_height=max_height,
)
if title:
content = Template("title_data.html.j2").render(title=title, data=table)
else:
content = table
return content

View file

@ -4,6 +4,7 @@ import glob
import io
import logging
import os
import pathlib
import re
import shutil
import subprocess
@ -194,10 +195,9 @@ ray_files += [
for filename in filenames
]
# Files for ray.init html template.
# html templates for notebook integration
ray_files += [
"ray/widgets/templates/context_dashrow.html.j2",
"ray/widgets/templates/context.html.j2",
p.as_posix() for p in pathlib.Path("ray/widgets/templates/").glob("*.html.j2")
]
# If you're adding dependencies for ray extras, please