mirror of
https://github.com/vale981/ray
synced 2025-03-05 10:01:43 -05:00
[AIR] Add rich notebook repr for DataParallelTrainer (#26335)
This commit is contained in:
parent
bceef503b2
commit
4d19c0222b
15 changed files with 502 additions and 4 deletions
|
@ -14,6 +14,7 @@ from typing import (
|
||||||
|
|
||||||
from ray.air.constants import WILDCARD_KEY
|
from ray.air.constants import WILDCARD_KEY
|
||||||
from ray.util.annotations import PublicAPI
|
from ray.util.annotations import PublicAPI
|
||||||
|
from ray.widgets import Template, make_table_html_repr
|
||||||
|
|
||||||
|
|
||||||
# Move here later when ml_utils is deprecated. Doing it now causes a circular import.
|
# Move here later when ml_utils is deprecated. Doing it now causes a circular import.
|
||||||
|
@ -135,6 +136,9 @@ class ScalingConfig:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return _repr_dataclass(self)
|
return _repr_dataclass(self)
|
||||||
|
|
||||||
|
def _repr_html_(self) -> str:
|
||||||
|
return make_table_html_repr(obj=self, title=type(self).__name__)
|
||||||
|
|
||||||
def __eq__(self, o: "ScalingConfig") -> bool:
|
def __eq__(self, o: "ScalingConfig") -> bool:
|
||||||
if not isinstance(o, type(self)):
|
if not isinstance(o, type(self)):
|
||||||
return False
|
return False
|
||||||
|
@ -323,6 +327,11 @@ class DatasetConfig:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return _repr_dataclass(self)
|
return _repr_dataclass(self)
|
||||||
|
|
||||||
|
def _repr_html_(self, title=None) -> str:
|
||||||
|
if title is None:
|
||||||
|
title = type(self).__name__
|
||||||
|
return make_table_html_repr(obj=self, title=title)
|
||||||
|
|
||||||
def fill_defaults(self) -> "DatasetConfig":
|
def fill_defaults(self) -> "DatasetConfig":
|
||||||
"""Return a copy of this config with all default values filled in."""
|
"""Return a copy of this config with all default values filled in."""
|
||||||
return DatasetConfig(
|
return DatasetConfig(
|
||||||
|
@ -460,6 +469,28 @@ class FailureConfig:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return _repr_dataclass(self)
|
return _repr_dataclass(self)
|
||||||
|
|
||||||
|
def _repr_html_(self):
|
||||||
|
try:
|
||||||
|
from tabulate import tabulate
|
||||||
|
except ImportError:
|
||||||
|
return (
|
||||||
|
"Tabulate isn't installed. Run "
|
||||||
|
"`pip install tabulate` for rich notebook output."
|
||||||
|
)
|
||||||
|
|
||||||
|
return Template("scrollableTable.html.j2").render(
|
||||||
|
table=tabulate(
|
||||||
|
{
|
||||||
|
"Setting": ["Max failures", "Fail fast"],
|
||||||
|
"Value": [self.max_failures, self.fail_fast],
|
||||||
|
},
|
||||||
|
tablefmt="html",
|
||||||
|
showindex=False,
|
||||||
|
headers="keys",
|
||||||
|
),
|
||||||
|
max_height="none",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@PublicAPI(stability="beta")
|
@PublicAPI(stability="beta")
|
||||||
|
@ -527,6 +558,55 @@ class CheckpointConfig:
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return _repr_dataclass(self)
|
return _repr_dataclass(self)
|
||||||
|
|
||||||
|
def _repr_html_(self) -> str:
|
||||||
|
try:
|
||||||
|
from tabulate import tabulate
|
||||||
|
except ImportError:
|
||||||
|
return (
|
||||||
|
"Tabulate isn't installed. Run "
|
||||||
|
"`pip install tabulate` for rich notebook output."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.num_to_keep is None:
|
||||||
|
num_to_keep_repr = "All"
|
||||||
|
else:
|
||||||
|
num_to_keep_repr = self.num_to_keep
|
||||||
|
|
||||||
|
if self.checkpoint_score_attribute is None:
|
||||||
|
checkpoint_score_attribute_repr = "Most recent"
|
||||||
|
else:
|
||||||
|
checkpoint_score_attribute_repr = self.checkpoint_score_attribute
|
||||||
|
|
||||||
|
if self.checkpoint_at_end is None:
|
||||||
|
checkpoint_at_end_repr = ""
|
||||||
|
else:
|
||||||
|
checkpoint_at_end_repr = self.checkpoint_at_end
|
||||||
|
|
||||||
|
return Template("scrollableTable.html.j2").render(
|
||||||
|
table=tabulate(
|
||||||
|
{
|
||||||
|
"Setting": [
|
||||||
|
"Number of checkpoints to keep",
|
||||||
|
"Checkpoint score attribute",
|
||||||
|
"Checkpoint score order",
|
||||||
|
"Checkpoint frequency",
|
||||||
|
"Checkpoint at end",
|
||||||
|
],
|
||||||
|
"Value": [
|
||||||
|
num_to_keep_repr,
|
||||||
|
checkpoint_score_attribute_repr,
|
||||||
|
self.checkpoint_score_order,
|
||||||
|
self.checkpoint_frequency,
|
||||||
|
checkpoint_at_end_repr,
|
||||||
|
],
|
||||||
|
},
|
||||||
|
tablefmt="html",
|
||||||
|
showindex=False,
|
||||||
|
headers="keys",
|
||||||
|
),
|
||||||
|
max_height="none",
|
||||||
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _tune_legacy_checkpoint_score_attr(self) -> Optional[str]:
|
def _tune_legacy_checkpoint_score_attr(self) -> Optional[str]:
|
||||||
"""Same as ``checkpoint_score_attr`` in ``tune.run``.
|
"""Same as ``checkpoint_score_attr`` in ``tune.run``.
|
||||||
|
@ -618,3 +698,59 @@ class RunConfig:
|
||||||
"checkpoint_config": CheckpointConfig(),
|
"checkpoint_config": CheckpointConfig(),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _repr_html_(self) -> str:
|
||||||
|
try:
|
||||||
|
from tabulate import tabulate
|
||||||
|
except ImportError:
|
||||||
|
return (
|
||||||
|
"Tabulate isn't installed. Run "
|
||||||
|
"`pip install tabulate` for rich notebook output."
|
||||||
|
)
|
||||||
|
|
||||||
|
reprs = []
|
||||||
|
if self.failure_config is not None:
|
||||||
|
reprs.append(
|
||||||
|
Template("title_data_mini.html.j2").render(
|
||||||
|
title="Failure Config", data=self.failure_config._repr_html_()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if self.sync_config is not None:
|
||||||
|
reprs.append(
|
||||||
|
Template("title_data_mini.html.j2").render(
|
||||||
|
title="Sync Config", data=self.sync_config._repr_html_()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if self.checkpoint_config is not None:
|
||||||
|
reprs.append(
|
||||||
|
Template("title_data_mini.html.j2").render(
|
||||||
|
title="Checkpoint Config", data=self.checkpoint_config._repr_html_()
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a divider between each displayed repr
|
||||||
|
subconfigs = [Template("divider.html.j2").render()] * (2 * len(reprs) - 1)
|
||||||
|
subconfigs[::2] = reprs
|
||||||
|
|
||||||
|
settings = Template("scrollableTable.html.j2").render(
|
||||||
|
table=tabulate(
|
||||||
|
{
|
||||||
|
"Name": self.name,
|
||||||
|
"Local results directory": self.local_dir,
|
||||||
|
"Verbosity": self.verbose,
|
||||||
|
"Log to file": self.log_to_file,
|
||||||
|
}.items(),
|
||||||
|
tablefmt="html",
|
||||||
|
headers=["Setting", "Value"],
|
||||||
|
showindex=False,
|
||||||
|
),
|
||||||
|
max_height="300px",
|
||||||
|
)
|
||||||
|
|
||||||
|
return Template("title_data.html.j2").render(
|
||||||
|
title="RunConfig",
|
||||||
|
data=Template("run_config.html.j2").render(
|
||||||
|
subconfigs=subconfigs,
|
||||||
|
settings=settings,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
|
@ -3,6 +3,7 @@ import itertools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
import html
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
Any,
|
Any,
|
||||||
|
@ -92,6 +93,7 @@ from ray.data.random_access_dataset import RandomAccessDataset
|
||||||
from ray.data.row import TableRow
|
from ray.data.row import TableRow
|
||||||
from ray.types import ObjectRef
|
from ray.types import ObjectRef
|
||||||
from ray.util.annotations import DeveloperAPI, PublicAPI
|
from ray.util.annotations import DeveloperAPI, PublicAPI
|
||||||
|
from ray.widgets import Template
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import dask
|
import dask
|
||||||
|
@ -3599,6 +3601,83 @@ class Dataset(Generic[T]):
|
||||||
else:
|
else:
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def _ipython_display_(self):
|
||||||
|
try:
|
||||||
|
from ipywidgets import HTML, VBox, Layout
|
||||||
|
except ImportError:
|
||||||
|
logger.warn(
|
||||||
|
"'ipywidgets' isn't installed. Run `pip install ipywidgets` to "
|
||||||
|
"enable notebook widgets."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
from IPython.display import display
|
||||||
|
|
||||||
|
title = HTML(f"<h2>{self.__class__.__name__}</h2>")
|
||||||
|
display(VBox([title, self._tab_repr_()], layout=Layout(width="100%")))
|
||||||
|
|
||||||
|
def _tab_repr_(self):
|
||||||
|
try:
|
||||||
|
from tabulate import tabulate
|
||||||
|
from ipywidgets import Tab, HTML
|
||||||
|
except ImportError:
|
||||||
|
logger.info(
|
||||||
|
"For rich Dataset reprs in notebooks, run "
|
||||||
|
"`pip install tabulate ipywidgets`."
|
||||||
|
)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
metadata = {
|
||||||
|
"num_blocks": self._plan.initial_num_blocks(),
|
||||||
|
"num_rows": self._meta_count(),
|
||||||
|
}
|
||||||
|
|
||||||
|
schema = self.schema()
|
||||||
|
if schema is None:
|
||||||
|
schema_repr = Template("rendered_html_common.html.j2").render(
|
||||||
|
content="<h5>Unknown schema</h5>"
|
||||||
|
)
|
||||||
|
elif isinstance(schema, type):
|
||||||
|
schema_repr = Template("rendered_html_common.html.j2").render(
|
||||||
|
content=f"<h5>Data type: <code>{html.escape(str(schema))}</code></h5>"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
schema_data = {}
|
||||||
|
for sname, stype in zip(schema.names, schema.types):
|
||||||
|
schema_data[sname] = getattr(stype, "__name__", str(stype))
|
||||||
|
|
||||||
|
schema_repr = Template("scrollableTable.html.j2").render(
|
||||||
|
table=tabulate(
|
||||||
|
tabular_data=schema_data.items(),
|
||||||
|
tablefmt="html",
|
||||||
|
showindex=False,
|
||||||
|
headers=["Name", "Type"],
|
||||||
|
),
|
||||||
|
max_height="300px",
|
||||||
|
)
|
||||||
|
|
||||||
|
tab = Tab()
|
||||||
|
children = []
|
||||||
|
|
||||||
|
tab.set_title(0, "Metadata")
|
||||||
|
children.append(
|
||||||
|
Template("scrollableTable.html.j2").render(
|
||||||
|
table=tabulate(
|
||||||
|
tabular_data=metadata.items(),
|
||||||
|
tablefmt="html",
|
||||||
|
showindex=False,
|
||||||
|
headers=["Field", "Value"],
|
||||||
|
),
|
||||||
|
max_height="300px",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
tab.set_title(1, "Schema")
|
||||||
|
children.append(schema_repr)
|
||||||
|
|
||||||
|
tab.children = [HTML(child) for child in children]
|
||||||
|
return tab
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
schema = self.schema()
|
schema = self.schema()
|
||||||
if schema is None:
|
if schema is None:
|
||||||
|
|
|
@ -4,6 +4,7 @@ from typing import TypeVar, Dict
|
||||||
from ray.train._internal.utils import Singleton
|
from ray.train._internal.utils import Singleton
|
||||||
from ray.train._internal.worker_group import WorkerGroup
|
from ray.train._internal.worker_group import WorkerGroup
|
||||||
from ray.util.annotations import DeveloperAPI
|
from ray.util.annotations import DeveloperAPI
|
||||||
|
from ray.widgets import make_table_html_repr
|
||||||
|
|
||||||
EncodedData = TypeVar("EncodedData")
|
EncodedData = TypeVar("EncodedData")
|
||||||
|
|
||||||
|
@ -18,6 +19,9 @@ class BackendConfig:
|
||||||
def backend_cls(self):
|
def backend_cls(self):
|
||||||
return Backend
|
return Backend
|
||||||
|
|
||||||
|
def _repr_html_(self) -> str:
|
||||||
|
return make_table_html_repr(obj=self, title=type(self).__name__)
|
||||||
|
|
||||||
|
|
||||||
@DeveloperAPI
|
@DeveloperAPI
|
||||||
class Backend(metaclass=Singleton):
|
class Backend(metaclass=Singleton):
|
||||||
|
|
|
@ -3,6 +3,7 @@ import logging
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
|
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Type, Union
|
||||||
|
from tabulate import tabulate
|
||||||
|
|
||||||
import ray
|
import ray
|
||||||
from ray import tune
|
from ray import tune
|
||||||
|
@ -19,6 +20,7 @@ from ray.train._internal.utils import construct_train_func
|
||||||
from ray.train.constants import TRAIN_DATASET_KEY, WILDCARD_KEY
|
from ray.train.constants import TRAIN_DATASET_KEY, WILDCARD_KEY
|
||||||
from ray.train.trainer import BaseTrainer, GenDataset
|
from ray.train.trainer import BaseTrainer, GenDataset
|
||||||
from ray.util.annotations import DeveloperAPI
|
from ray.util.annotations import DeveloperAPI
|
||||||
|
from ray.widgets import Template
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.data.preprocessor import Preprocessor
|
from ray.data.preprocessor import Preprocessor
|
||||||
|
@ -372,6 +374,115 @@ class DataParallelTrainer(BaseTrainer):
|
||||||
"""
|
"""
|
||||||
return self._dataset_config.copy()
|
return self._dataset_config.copy()
|
||||||
|
|
||||||
|
def _ipython_display_(self):
|
||||||
|
try:
|
||||||
|
from ipywidgets import HTML, VBox, Tab, Layout
|
||||||
|
except ImportError:
|
||||||
|
logger.warn(
|
||||||
|
"'ipywidgets' isn't installed. Run `pip install ipywidgets` to "
|
||||||
|
"enable notebook widgets."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
|
||||||
|
from IPython.display import display
|
||||||
|
|
||||||
|
title = HTML(f"<h2>{self.__class__.__name__}</h2>")
|
||||||
|
|
||||||
|
tab = Tab()
|
||||||
|
children = []
|
||||||
|
|
||||||
|
tab.set_title(0, "Datasets")
|
||||||
|
children.append(self._datasets_repr_() if self.datasets else None)
|
||||||
|
|
||||||
|
tab.set_title(1, "Dataset Config")
|
||||||
|
children.append(
|
||||||
|
HTML(self._dataset_config_repr_html_()) if self._dataset_config else None
|
||||||
|
)
|
||||||
|
|
||||||
|
tab.set_title(2, "Train Loop Config")
|
||||||
|
children.append(
|
||||||
|
HTML(self._train_loop_config_repr_html_())
|
||||||
|
if self._train_loop_config
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
tab.set_title(3, "Scaling Config")
|
||||||
|
children.append(
|
||||||
|
HTML(self.scaling_config._repr_html_()) if self.scaling_config else None
|
||||||
|
)
|
||||||
|
|
||||||
|
tab.set_title(4, "Run Config")
|
||||||
|
children.append(
|
||||||
|
HTML(self.run_config._repr_html_()) if self.run_config else None
|
||||||
|
)
|
||||||
|
|
||||||
|
tab.set_title(5, "Backend Config")
|
||||||
|
children.append(
|
||||||
|
HTML(self._backend_config._repr_html_()) if self._backend_config else None
|
||||||
|
)
|
||||||
|
|
||||||
|
tab.children = children
|
||||||
|
display(VBox([title, tab], layout=Layout(width="100%")))
|
||||||
|
|
||||||
|
def _train_loop_config_repr_html_(self) -> str:
|
||||||
|
if self._train_loop_config:
|
||||||
|
table_data = {}
|
||||||
|
for k, v in self._train_loop_config.items():
|
||||||
|
if isinstance(v, str) or str(v).isnumeric():
|
||||||
|
table_data[k] = v
|
||||||
|
elif hasattr(v, "_repr_html_"):
|
||||||
|
table_data[k] = v._repr_html_()
|
||||||
|
else:
|
||||||
|
table_data[k] = str(v)
|
||||||
|
|
||||||
|
return Template("title_data.html.j2").render(
|
||||||
|
title="Train Loop Config",
|
||||||
|
data=Template("scrollableTable.html.j2").render(
|
||||||
|
table=tabulate(
|
||||||
|
table_data.items(),
|
||||||
|
headers=["Setting", "Value"],
|
||||||
|
showindex=False,
|
||||||
|
tablefmt="unsafehtml",
|
||||||
|
),
|
||||||
|
max_height="none",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _dataset_config_repr_html_(self) -> str:
|
||||||
|
content = []
|
||||||
|
if self._dataset_config:
|
||||||
|
for name, config in self._dataset_config.items():
|
||||||
|
content.append(
|
||||||
|
config._repr_html_(title=f"DatasetConfig - <code>{name}</code>")
|
||||||
|
)
|
||||||
|
|
||||||
|
return Template("rendered_html_common.html.j2").render(content=content)
|
||||||
|
|
||||||
|
def _datasets_repr_(self) -> str:
|
||||||
|
try:
|
||||||
|
from ipywidgets import HTML, VBox, Layout
|
||||||
|
except ImportError:
|
||||||
|
logger.warn(
|
||||||
|
"'ipywidgets' isn't installed. Run `pip install ipywidgets` to "
|
||||||
|
"enable notebook widgets."
|
||||||
|
)
|
||||||
|
return None
|
||||||
|
content = []
|
||||||
|
if self.datasets:
|
||||||
|
for name, config in self.datasets.items():
|
||||||
|
content.append(
|
||||||
|
HTML(
|
||||||
|
Template("title_data.html.j2").render(
|
||||||
|
title=f"Dataset - <code>{name}</code>", data=None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
content.append(config._tab_repr_())
|
||||||
|
|
||||||
|
return VBox(content, layout=Layout(width="100%"))
|
||||||
|
|
||||||
|
|
||||||
def _load_checkpoint(
|
def _load_checkpoint(
|
||||||
checkpoint: Checkpoint, trainer_name: str
|
checkpoint: Checkpoint, trainer_name: str
|
||||||
|
|
|
@ -30,6 +30,7 @@ from ray.tune.callback import Callback
|
||||||
from ray.tune.result import NODE_IP
|
from ray.tune.result import NODE_IP
|
||||||
from ray.tune.utils.file_transfer import sync_dir_between_nodes
|
from ray.tune.utils.file_transfer import sync_dir_between_nodes
|
||||||
from ray.util.annotations import PublicAPI, DeveloperAPI
|
from ray.util.annotations import PublicAPI, DeveloperAPI
|
||||||
|
from ray.widgets import Template
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ray.tune.experiment import Trial
|
from ray.tune.experiment import Trial
|
||||||
|
@ -93,6 +94,41 @@ class SyncConfig:
|
||||||
sync_on_checkpoint: bool = True
|
sync_on_checkpoint: bool = True
|
||||||
sync_period: int = DEFAULT_SYNC_PERIOD
|
sync_period: int = DEFAULT_SYNC_PERIOD
|
||||||
|
|
||||||
|
def _repr_html_(self) -> str:
|
||||||
|
"""Generate an HTML representation of the SyncConfig.
|
||||||
|
|
||||||
|
Note that self.syncer is omitted here; seems to have some overlap
|
||||||
|
with existing configuration settings here in the SyncConfig class.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from tabulate import tabulate
|
||||||
|
except ImportError:
|
||||||
|
return (
|
||||||
|
"Tabulate isn't installed. Run "
|
||||||
|
"`pip install tabulate` for rich notebook output."
|
||||||
|
)
|
||||||
|
|
||||||
|
return Template("scrollableTable.html.j2").render(
|
||||||
|
table=tabulate(
|
||||||
|
{
|
||||||
|
"Setting": [
|
||||||
|
"Upload directory",
|
||||||
|
"Sync on checkpoint",
|
||||||
|
"Sync period",
|
||||||
|
],
|
||||||
|
"Value": [
|
||||||
|
self.upload_dir,
|
||||||
|
self.sync_on_checkpoint,
|
||||||
|
self.sync_period,
|
||||||
|
],
|
||||||
|
},
|
||||||
|
tablefmt="html",
|
||||||
|
showindex=False,
|
||||||
|
headers="keys",
|
||||||
|
),
|
||||||
|
max_height="none",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class _BackgroundProcess:
|
class _BackgroundProcess:
|
||||||
def __init__(self, fn: Callable):
|
def __init__(self, fn: Callable):
|
||||||
|
@ -308,6 +344,9 @@ class Syncer(abc.ABC):
|
||||||
def close(self):
|
def close(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _repr_html_(self) -> str:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
class _BackgroundSyncer(Syncer):
|
class _BackgroundSyncer(Syncer):
|
||||||
"""Syncer using a background process for asynchronous file transfer."""
|
"""Syncer using a background process for asynchronous file transfer."""
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
from ray.widgets.render import Template
|
from ray.widgets.render import Template
|
||||||
|
from ray.widgets.util import make_table_html_repr
|
||||||
|
|
||||||
__all__ = ["Template"]
|
__all__ = ["Template", "make_table_html_repr"]
|
||||||
|
|
|
@ -24,6 +24,8 @@ class Template:
|
||||||
"""
|
"""
|
||||||
rendered = self.template
|
rendered = self.template
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
|
if isinstance(value, List):
|
||||||
|
value = "".join(value)
|
||||||
rendered = rendered.replace("{{ " + key + " }}", value if value else "")
|
rendered = rendered.replace("{{ " + key + " }}", value if value else "")
|
||||||
return rendered
|
return rendered
|
||||||
|
|
||||||
|
|
9
python/ray/widgets/templates/divider.html.j2
Normal file
9
python/ray/widgets/templates/divider.html.j2
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
<div class="vDivider"></div>
|
||||||
|
<style>
|
||||||
|
.vDivider {
|
||||||
|
border-left-width: var(--jp-border-width);
|
||||||
|
border-left-color: var(--jp-border-color0);
|
||||||
|
border-left-style: solid;
|
||||||
|
margin: 0.5em 1em 0.5em 1em;
|
||||||
|
}
|
||||||
|
</style>
|
|
@ -0,0 +1,3 @@
|
||||||
|
<div class='jp-RenderedHTMLCommon'>
|
||||||
|
{{ content }}
|
||||||
|
</div>
|
18
python/ray/widgets/templates/run_config.html.j2
Normal file
18
python/ray/widgets/templates/run_config.html.j2
Normal file
|
@ -0,0 +1,18 @@
|
||||||
|
<div class="runConfig">
|
||||||
|
<div class="generalSettings">
|
||||||
|
{{ settings }}
|
||||||
|
</div>
|
||||||
|
<div class="sideBySide">
|
||||||
|
{{ subconfigs }}
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
<style>
|
||||||
|
.sideBySide {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: row;
|
||||||
|
gap: 1em;
|
||||||
|
}
|
||||||
|
.generalSettings {
|
||||||
|
border-bottom: var(--jp-border-width) solid var(--jp-border-color0);
|
||||||
|
}
|
||||||
|
</style>
|
20
python/ray/widgets/templates/scrollableTable.html.j2
Normal file
20
python/ray/widgets/templates/scrollableTable.html.j2
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
<div class="scrollableTable jp-RenderedHTMLCommon">
|
||||||
|
{{ table }}
|
||||||
|
</div>
|
||||||
|
<style>
|
||||||
|
.scrollableTable {
|
||||||
|
overflow-y: auto;
|
||||||
|
max-height: {{ max_height }};
|
||||||
|
}
|
||||||
|
.scrollableTable table {
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
.scrollableTable table :is(th,td) {
|
||||||
|
text-align: left !important;
|
||||||
|
}
|
||||||
|
.scrollableTable th {
|
||||||
|
background: var(--jp-layout-color1);
|
||||||
|
position: sticky;
|
||||||
|
top: 0;
|
||||||
|
}
|
||||||
|
</style>
|
11
python/ray/widgets/templates/title_data.html.j2
Normal file
11
python/ray/widgets/templates/title_data.html.j2
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
<div class='titleData jp-RenderedHTMLCommon'>
|
||||||
|
<h3>{{ title }}</h3>
|
||||||
|
{{ data }}
|
||||||
|
</div>
|
||||||
|
<style>
|
||||||
|
.titleData h3 {
|
||||||
|
border-bottom-width: var(--jp-border-width);
|
||||||
|
border-bottom-color: var(--jp-border-color0);
|
||||||
|
border-bottom-style: solid;
|
||||||
|
}
|
||||||
|
</style>
|
4
python/ray/widgets/templates/title_data_mini.html.j2
Normal file
4
python/ray/widgets/templates/title_data_mini.html.j2
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
<div class='miniTitleData jp-RenderedHTMLCommon'>
|
||||||
|
<h4><b>{{ title }}</b></h4>
|
||||||
|
{{ data }}
|
||||||
|
</div>
|
61
python/ray/widgets/util.py
Normal file
61
python/ray/widgets/util.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from ray.util.annotations import DeveloperAPI
|
||||||
|
from ray.widgets import Template
|
||||||
|
|
||||||
|
|
||||||
|
@DeveloperAPI
|
||||||
|
def make_table_html_repr(
|
||||||
|
obj: Any, title: Optional[str] = None, max_height: str = "none"
|
||||||
|
) -> str:
|
||||||
|
"""Generate a generic html repr using a table.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: Object for which a repr is to be generated
|
||||||
|
title: If present, a title for the section is included
|
||||||
|
max_height: Maximum height of the table; valid values
|
||||||
|
are given by the max-height CSS property
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
HTML representation of the object
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from tabulate import tabulate
|
||||||
|
except ImportError:
|
||||||
|
return (
|
||||||
|
"Tabulate isn't installed. Run "
|
||||||
|
"`pip install tabulate` for rich notebook output."
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {}
|
||||||
|
for k, v in vars(obj).items():
|
||||||
|
if isinstance(v, (str, bool, int, float)):
|
||||||
|
data[k] = str(v)
|
||||||
|
|
||||||
|
elif isinstance(v, dict) or hasattr(v, "__dict__"):
|
||||||
|
data[k] = Template("scrollableTable.html.j2").render(
|
||||||
|
table=tabulate(
|
||||||
|
v.items() if isinstance(v, dict) else vars(v).items(),
|
||||||
|
tablefmt="html",
|
||||||
|
showindex=False,
|
||||||
|
headers=["Setting", "Value"],
|
||||||
|
),
|
||||||
|
max_height="none",
|
||||||
|
)
|
||||||
|
|
||||||
|
table = Template("scrollableTable.html.j2").render(
|
||||||
|
table=tabulate(
|
||||||
|
data.items(),
|
||||||
|
tablefmt="unsafehtml",
|
||||||
|
showindex=False,
|
||||||
|
headers=["Setting", "Value"],
|
||||||
|
),
|
||||||
|
max_height=max_height,
|
||||||
|
)
|
||||||
|
|
||||||
|
if title:
|
||||||
|
content = Template("title_data.html.j2").render(title=title, data=table)
|
||||||
|
else:
|
||||||
|
content = table
|
||||||
|
|
||||||
|
return content
|
|
@ -4,6 +4,7 @@ import glob
|
||||||
import io
|
import io
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import pathlib
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import subprocess
|
import subprocess
|
||||||
|
@ -194,10 +195,9 @@ ray_files += [
|
||||||
for filename in filenames
|
for filename in filenames
|
||||||
]
|
]
|
||||||
|
|
||||||
# Files for ray.init html template.
|
# html templates for notebook integration
|
||||||
ray_files += [
|
ray_files += [
|
||||||
"ray/widgets/templates/context_dashrow.html.j2",
|
p.as_posix() for p in pathlib.Path("ray/widgets/templates/").glob("*.html.j2")
|
||||||
"ray/widgets/templates/context.html.j2",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
# If you're adding dependencies for ray extras, please
|
# If you're adding dependencies for ray extras, please
|
||||||
|
|
Loading…
Add table
Reference in a new issue