[AIR] Add rich notebook repr for DataParallelTrainer (#26335)

This commit is contained in:
Peyton Murray 2022-08-16 08:51:14 -07:00 committed by GitHub
parent bceef503b2
commit 4d19c0222b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 502 additions and 4 deletions

View file

@ -14,6 +14,7 @@ from typing import (
from ray.air.constants import WILDCARD_KEY from ray.air.constants import WILDCARD_KEY
from ray.util.annotations import PublicAPI from ray.util.annotations import PublicAPI
from ray.widgets import Template, make_table_html_repr
# Move here later when ml_utils is deprecated. Doing it now causes a circular import. # Move here later when ml_utils is deprecated. Doing it now causes a circular import.
@ -135,6 +136,9 @@ class ScalingConfig:
def __repr__(self): def __repr__(self):
return _repr_dataclass(self) return _repr_dataclass(self)
def _repr_html_(self) -> str:
return make_table_html_repr(obj=self, title=type(self).__name__)
def __eq__(self, o: "ScalingConfig") -> bool: def __eq__(self, o: "ScalingConfig") -> bool:
if not isinstance(o, type(self)): if not isinstance(o, type(self)):
return False return False
@ -323,6 +327,11 @@ class DatasetConfig:
def __repr__(self): def __repr__(self):
return _repr_dataclass(self) return _repr_dataclass(self)
def _repr_html_(self, title=None) -> str:
if title is None:
title = type(self).__name__
return make_table_html_repr(obj=self, title=title)
def fill_defaults(self) -> "DatasetConfig": def fill_defaults(self) -> "DatasetConfig":
"""Return a copy of this config with all default values filled in.""" """Return a copy of this config with all default values filled in."""
return DatasetConfig( return DatasetConfig(
@ -460,6 +469,28 @@ class FailureConfig:
def __repr__(self): def __repr__(self):
return _repr_dataclass(self) return _repr_dataclass(self)
def _repr_html_(self):
try:
from tabulate import tabulate
except ImportError:
return (
"Tabulate isn't installed. Run "
"`pip install tabulate` for rich notebook output."
)
return Template("scrollableTable.html.j2").render(
table=tabulate(
{
"Setting": ["Max failures", "Fail fast"],
"Value": [self.max_failures, self.fail_fast],
},
tablefmt="html",
showindex=False,
headers="keys",
),
max_height="none",
)
@dataclass @dataclass
@PublicAPI(stability="beta") @PublicAPI(stability="beta")
@ -527,6 +558,55 @@ class CheckpointConfig:
def __repr__(self): def __repr__(self):
return _repr_dataclass(self) return _repr_dataclass(self)
def _repr_html_(self) -> str:
try:
from tabulate import tabulate
except ImportError:
return (
"Tabulate isn't installed. Run "
"`pip install tabulate` for rich notebook output."
)
if self.num_to_keep is None:
num_to_keep_repr = "All"
else:
num_to_keep_repr = self.num_to_keep
if self.checkpoint_score_attribute is None:
checkpoint_score_attribute_repr = "Most recent"
else:
checkpoint_score_attribute_repr = self.checkpoint_score_attribute
if self.checkpoint_at_end is None:
checkpoint_at_end_repr = ""
else:
checkpoint_at_end_repr = self.checkpoint_at_end
return Template("scrollableTable.html.j2").render(
table=tabulate(
{
"Setting": [
"Number of checkpoints to keep",
"Checkpoint score attribute",
"Checkpoint score order",
"Checkpoint frequency",
"Checkpoint at end",
],
"Value": [
num_to_keep_repr,
checkpoint_score_attribute_repr,
self.checkpoint_score_order,
self.checkpoint_frequency,
checkpoint_at_end_repr,
],
},
tablefmt="html",
showindex=False,
headers="keys",
),
max_height="none",
)
@property @property
def _tune_legacy_checkpoint_score_attr(self) -> Optional[str]: def _tune_legacy_checkpoint_score_attr(self) -> Optional[str]:
"""Same as ``checkpoint_score_attr`` in ``tune.run``. """Same as ``checkpoint_score_attr`` in ``tune.run``.
@ -618,3 +698,59 @@ class RunConfig:
"checkpoint_config": CheckpointConfig(), "checkpoint_config": CheckpointConfig(),
}, },
) )
def _repr_html_(self) -> str:
try:
from tabulate import tabulate
except ImportError:
return (
"Tabulate isn't installed. Run "
"`pip install tabulate` for rich notebook output."
)
reprs = []
if self.failure_config is not None:
reprs.append(
Template("title_data_mini.html.j2").render(
title="Failure Config", data=self.failure_config._repr_html_()
)
)
if self.sync_config is not None:
reprs.append(
Template("title_data_mini.html.j2").render(
title="Sync Config", data=self.sync_config._repr_html_()
)
)
if self.checkpoint_config is not None:
reprs.append(
Template("title_data_mini.html.j2").render(
title="Checkpoint Config", data=self.checkpoint_config._repr_html_()
)
)
# Create a divider between each displayed repr
subconfigs = [Template("divider.html.j2").render()] * (2 * len(reprs) - 1)
subconfigs[::2] = reprs
settings = Template("scrollableTable.html.j2").render(
table=tabulate(
{
"Name": self.name,
"Local results directory": self.local_dir,
"Verbosity": self.verbose,
"Log to file": self.log_to_file,
}.items(),
tablefmt="html",
headers=["Setting", "Value"],
showindex=False,
),
max_height="300px",
)
return Template("title_data.html.j2").render(
title="RunConfig",
data=Template("run_config.html.j2").render(
subconfigs=subconfigs,
settings=settings,
),
)

View file

@ -3,6 +3,7 @@ import itertools
import logging import logging
import os import os
import time import time
import html
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -92,6 +93,7 @@ from ray.data.random_access_dataset import RandomAccessDataset
from ray.data.row import TableRow from ray.data.row import TableRow
from ray.types import ObjectRef from ray.types import ObjectRef
from ray.util.annotations import DeveloperAPI, PublicAPI from ray.util.annotations import DeveloperAPI, PublicAPI
from ray.widgets import Template
if TYPE_CHECKING: if TYPE_CHECKING:
import dask import dask
@ -3599,6 +3601,83 @@ class Dataset(Generic[T]):
else: else:
return result return result
def _ipython_display_(self):
try:
from ipywidgets import HTML, VBox, Layout
except ImportError:
logger.warn(
"'ipywidgets' isn't installed. Run `pip install ipywidgets` to "
"enable notebook widgets."
)
return None
from IPython.display import display
title = HTML(f"<h2>{self.__class__.__name__}</h2>")
display(VBox([title, self._tab_repr_()], layout=Layout(width="100%")))
def _tab_repr_(self):
try:
from tabulate import tabulate
from ipywidgets import Tab, HTML
except ImportError:
logger.info(
"For rich Dataset reprs in notebooks, run "
"`pip install tabulate ipywidgets`."
)
return ""
metadata = {
"num_blocks": self._plan.initial_num_blocks(),
"num_rows": self._meta_count(),
}
schema = self.schema()
if schema is None:
schema_repr = Template("rendered_html_common.html.j2").render(
content="<h5>Unknown schema</h5>"
)
elif isinstance(schema, type):
schema_repr = Template("rendered_html_common.html.j2").render(
content=f"<h5>Data type: <code>{html.escape(str(schema))}</code></h5>"
)
else:
schema_data = {}
for sname, stype in zip(schema.names, schema.types):
schema_data[sname] = getattr(stype, "__name__", str(stype))
schema_repr = Template("scrollableTable.html.j2").render(
table=tabulate(
tabular_data=schema_data.items(),
tablefmt="html",
showindex=False,
headers=["Name", "Type"],
),
max_height="300px",
)
tab = Tab()
children = []
tab.set_title(0, "Metadata")
children.append(
Template("scrollableTable.html.j2").render(
table=tabulate(
tabular_data=metadata.items(),
tablefmt="html",
showindex=False,
headers=["Field", "Value"],
),
max_height="300px",
)
)
tab.set_title(1, "Schema")
children.append(schema_repr)
tab.children = [HTML(child) for child in children]
return tab
def __repr__(self) -> str: def __repr__(self) -> str:
schema = self.schema() schema = self.schema()
if schema is None: if schema is None:

View file

@ -4,6 +4,7 @@ from typing import TypeVar, Dict
from ray.train._internal.utils import Singleton from ray.train._internal.utils import Singleton
from ray.train._internal.worker_group import WorkerGroup from ray.train._internal.worker_group import WorkerGroup
from ray.util.annotations import DeveloperAPI from ray.util.annotations import DeveloperAPI
from ray.widgets import make_table_html_repr
EncodedData = TypeVar("EncodedData") EncodedData = TypeVar("EncodedData")
@ -18,6 +19,9 @@ class BackendConfig:
def backend_cls(self): def backend_cls(self):
return Backend return Backend
def _repr_html_(self) -> str:
return make_table_html_repr(obj=self, title=type(self).__name__)
@DeveloperAPI @DeveloperAPI
class Backend(metaclass=Singleton): class Backend(metaclass=Singleton):

View file

@ -3,6 +3,7 @@ import logging
import os import os
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
from tabulate import tabulate
import ray import ray
from ray import tune from ray import tune
@ -19,6 +20,7 @@ from ray.train._internal.utils import construct_train_func
from ray.train.constants import TRAIN_DATASET_KEY, WILDCARD_KEY from ray.train.constants import TRAIN_DATASET_KEY, WILDCARD_KEY
from ray.train.trainer import BaseTrainer, GenDataset from ray.train.trainer import BaseTrainer, GenDataset
from ray.util.annotations import DeveloperAPI from ray.util.annotations import DeveloperAPI
from ray.widgets import Template
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.data.preprocessor import Preprocessor from ray.data.preprocessor import Preprocessor
@ -372,6 +374,115 @@ class DataParallelTrainer(BaseTrainer):
""" """
return self._dataset_config.copy() return self._dataset_config.copy()
def _ipython_display_(self):
try:
from ipywidgets import HTML, VBox, Tab, Layout
except ImportError:
logger.warn(
"'ipywidgets' isn't installed. Run `pip install ipywidgets` to "
"enable notebook widgets."
)
return None
from IPython.display import display
title = HTML(f"<h2>{self.__class__.__name__}</h2>")
tab = Tab()
children = []
tab.set_title(0, "Datasets")
children.append(self._datasets_repr_() if self.datasets else None)
tab.set_title(1, "Dataset Config")
children.append(
HTML(self._dataset_config_repr_html_()) if self._dataset_config else None
)
tab.set_title(2, "Train Loop Config")
children.append(
HTML(self._train_loop_config_repr_html_())
if self._train_loop_config
else None
)
tab.set_title(3, "Scaling Config")
children.append(
HTML(self.scaling_config._repr_html_()) if self.scaling_config else None
)
tab.set_title(4, "Run Config")
children.append(
HTML(self.run_config._repr_html_()) if self.run_config else None
)
tab.set_title(5, "Backend Config")
children.append(
HTML(self._backend_config._repr_html_()) if self._backend_config else None
)
tab.children = children
display(VBox([title, tab], layout=Layout(width="100%")))
def _train_loop_config_repr_html_(self) -> str:
if self._train_loop_config:
table_data = {}
for k, v in self._train_loop_config.items():
if isinstance(v, str) or str(v).isnumeric():
table_data[k] = v
elif hasattr(v, "_repr_html_"):
table_data[k] = v._repr_html_()
else:
table_data[k] = str(v)
return Template("title_data.html.j2").render(
title="Train Loop Config",
data=Template("scrollableTable.html.j2").render(
table=tabulate(
table_data.items(),
headers=["Setting", "Value"],
showindex=False,
tablefmt="unsafehtml",
),
max_height="none",
),
)
else:
return ""
def _dataset_config_repr_html_(self) -> str:
content = []
if self._dataset_config:
for name, config in self._dataset_config.items():
content.append(
config._repr_html_(title=f"DatasetConfig - <code>{name}</code>")
)
return Template("rendered_html_common.html.j2").render(content=content)
def _datasets_repr_(self) -> str:
try:
from ipywidgets import HTML, VBox, Layout
except ImportError:
logger.warn(
"'ipywidgets' isn't installed. Run `pip install ipywidgets` to "
"enable notebook widgets."
)
return None
content = []
if self.datasets:
for name, config in self.datasets.items():
content.append(
HTML(
Template("title_data.html.j2").render(
title=f"Dataset - <code>{name}</code>", data=None
)
)
)
content.append(config._tab_repr_())
return VBox(content, layout=Layout(width="100%"))
def _load_checkpoint( def _load_checkpoint(
checkpoint: Checkpoint, trainer_name: str checkpoint: Checkpoint, trainer_name: str

View file

@ -30,6 +30,7 @@ from ray.tune.callback import Callback
from ray.tune.result import NODE_IP from ray.tune.result import NODE_IP
from ray.tune.utils.file_transfer import sync_dir_between_nodes from ray.tune.utils.file_transfer import sync_dir_between_nodes
from ray.util.annotations import PublicAPI, DeveloperAPI from ray.util.annotations import PublicAPI, DeveloperAPI
from ray.widgets import Template
if TYPE_CHECKING: if TYPE_CHECKING:
from ray.tune.experiment import Trial from ray.tune.experiment import Trial
@ -93,6 +94,41 @@ class SyncConfig:
sync_on_checkpoint: bool = True sync_on_checkpoint: bool = True
sync_period: int = DEFAULT_SYNC_PERIOD sync_period: int = DEFAULT_SYNC_PERIOD
def _repr_html_(self) -> str:
"""Generate an HTML representation of the SyncConfig.
Note that self.syncer is omitted here; seems to have some overlap
with existing configuration settings here in the SyncConfig class.
"""
try:
from tabulate import tabulate
except ImportError:
return (
"Tabulate isn't installed. Run "
"`pip install tabulate` for rich notebook output."
)
return Template("scrollableTable.html.j2").render(
table=tabulate(
{
"Setting": [
"Upload directory",
"Sync on checkpoint",
"Sync period",
],
"Value": [
self.upload_dir,
self.sync_on_checkpoint,
self.sync_period,
],
},
tablefmt="html",
showindex=False,
headers="keys",
),
max_height="none",
)
class _BackgroundProcess: class _BackgroundProcess:
def __init__(self, fn: Callable): def __init__(self, fn: Callable):
@ -308,6 +344,9 @@ class Syncer(abc.ABC):
def close(self): def close(self):
pass pass
def _repr_html_(self) -> str:
return
class _BackgroundSyncer(Syncer): class _BackgroundSyncer(Syncer):
"""Syncer using a background process for asynchronous file transfer.""" """Syncer using a background process for asynchronous file transfer."""

View file

@ -1,3 +1,4 @@
from ray.widgets.render import Template from ray.widgets.render import Template
from ray.widgets.util import make_table_html_repr
__all__ = ["Template"] __all__ = ["Template", "make_table_html_repr"]

View file

@ -24,6 +24,8 @@ class Template:
""" """
rendered = self.template rendered = self.template
for key, value in kwargs.items(): for key, value in kwargs.items():
if isinstance(value, List):
value = "".join(value)
rendered = rendered.replace("{{ " + key + " }}", value if value else "") rendered = rendered.replace("{{ " + key + " }}", value if value else "")
return rendered return rendered

View file

@ -0,0 +1,9 @@
<div class="vDivider"></div>
<style>
.vDivider {
border-left-width: var(--jp-border-width);
border-left-color: var(--jp-border-color0);
border-left-style: solid;
margin: 0.5em 1em 0.5em 1em;
}
</style>

View file

@ -0,0 +1,3 @@
<div class='jp-RenderedHTMLCommon'>
{{ content }}
</div>

View file

@ -0,0 +1,18 @@
<div class="runConfig">
<div class="generalSettings">
{{ settings }}
</div>
<div class="sideBySide">
{{ subconfigs }}
</div>
</div>
<style>
.sideBySide {
display: flex;
flex-direction: row;
gap: 1em;
}
.generalSettings {
border-bottom: var(--jp-border-width) solid var(--jp-border-color0);
}
</style>

View file

@ -0,0 +1,20 @@
<div class="scrollableTable jp-RenderedHTMLCommon">
{{ table }}
</div>
<style>
.scrollableTable {
overflow-y: auto;
max-height: {{ max_height }};
}
.scrollableTable table {
width: 100%;
}
.scrollableTable table :is(th,td) {
text-align: left !important;
}
.scrollableTable th {
background: var(--jp-layout-color1);
position: sticky;
top: 0;
}
</style>

View file

@ -0,0 +1,11 @@
<div class='titleData jp-RenderedHTMLCommon'>
<h3>{{ title }}</h3>
{{ data }}
</div>
<style>
.titleData h3 {
border-bottom-width: var(--jp-border-width);
border-bottom-color: var(--jp-border-color0);
border-bottom-style: solid;
}
</style>

View file

@ -0,0 +1,4 @@
<div class='miniTitleData jp-RenderedHTMLCommon'>
<h4><b>{{ title }}</b></h4>
{{ data }}
</div>

View file

@ -0,0 +1,61 @@
from typing import Any, Optional
from ray.util.annotations import DeveloperAPI
from ray.widgets import Template
@DeveloperAPI
def make_table_html_repr(
obj: Any, title: Optional[str] = None, max_height: str = "none"
) -> str:
"""Generate a generic html repr using a table.
Args:
obj: Object for which a repr is to be generated
title: If present, a title for the section is included
max_height: Maximum height of the table; valid values
are given by the max-height CSS property
Returns:
HTML representation of the object
"""
try:
from tabulate import tabulate
except ImportError:
return (
"Tabulate isn't installed. Run "
"`pip install tabulate` for rich notebook output."
)
data = {}
for k, v in vars(obj).items():
if isinstance(v, (str, bool, int, float)):
data[k] = str(v)
elif isinstance(v, dict) or hasattr(v, "__dict__"):
data[k] = Template("scrollableTable.html.j2").render(
table=tabulate(
v.items() if isinstance(v, dict) else vars(v).items(),
tablefmt="html",
showindex=False,
headers=["Setting", "Value"],
),
max_height="none",
)
table = Template("scrollableTable.html.j2").render(
table=tabulate(
data.items(),
tablefmt="unsafehtml",
showindex=False,
headers=["Setting", "Value"],
),
max_height=max_height,
)
if title:
content = Template("title_data.html.j2").render(title=title, data=table)
else:
content = table
return content

View file

@ -4,6 +4,7 @@ import glob
import io import io
import logging import logging
import os import os
import pathlib
import re import re
import shutil import shutil
import subprocess import subprocess
@ -194,10 +195,9 @@ ray_files += [
for filename in filenames for filename in filenames
] ]
# Files for ray.init html template. # html templates for notebook integration
ray_files += [ ray_files += [
"ray/widgets/templates/context_dashrow.html.j2", p.as_posix() for p in pathlib.Path("ray/widgets/templates/").glob("*.html.j2")
"ray/widgets/templates/context.html.j2",
] ]
# If you're adding dependencies for ray extras, please # If you're adding dependencies for ray extras, please