Source code for pytorch_lightning.callbacks.rich_model_summary
# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.fromtypingimportList,Tuplefrompytorch_lightning.callbacksimportModelSummaryfrompytorch_lightning.utilities.importsimport_RICH_AVAILABLEfrompytorch_lightning.utilities.model_summaryimportget_human_readable_countif_RICH_AVAILABLE:fromrichimportget_consolefromrich.tableimportTable
class RichModelSummary(ModelSummary):
    r"""Generates a summary of all layers in a :class:`~pytorch_lightning.core.module.LightningModule`
    with `rich text formatting <https://github.com/Textualize/rich>`_.

    Install it with pip:

    .. code-block:: bash

        pip install rich

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import RichModelSummary

        trainer = Trainer(callbacks=RichModelSummary())

    You could also enable ``RichModelSummary`` using the :class:`~pytorch_lightning.callbacks.RichProgressBar`

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.callbacks import RichProgressBar

        trainer = Trainer(callbacks=RichProgressBar())

    Args:
        max_depth: The maximum depth of layer nesting that the summary will include. A value of 0 turns the
            layer summary off.

    Raises:
        ModuleNotFoundError:
            If required `rich` package is not installed on the device.
    """

    def __init__(self, max_depth: int = 1) -> None:
        # Fail fast at construction time rather than when the summary is first printed.
        if not _RICH_AVAILABLE:
            raise ModuleNotFoundError(
                "`RichModelSummary` requires `rich` to be installed. Install it by running `pip install -U rich`."
            )
        super().__init__(max_depth)

    @staticmethod
    def summarize(
        summary_data: List[Tuple[str, List[str]]],
        total_parameters: int,
        trainable_parameters: int,
        model_size: float,
    ) -> None:
        """Print the per-layer table followed by the parameter totals to the ``rich`` console.

        Args:
            summary_data: Column-major table data as ``(column_name, cell_values)`` pairs. The header below
                hard-codes the index, "Name", "Type" and "Params" columns, so the entries are presumably
                expected to start with those four, in that order (verify against the caller). The
                "In sizes" / "Out sizes" headers are added only when those columns are present.
            total_parameters: Total number of parameters in the model.
            trainable_parameters: Number of trainable parameters; the non-trainable count is derived as
                ``total_parameters - trainable_parameters``.
            model_size: Estimated size of the model parameters, in megabytes.
        """
        console = get_console()

        table = Table(header_style="bold magenta")
        table.add_column(" ", style="dim")
        table.add_column("Name", justify="left", no_wrap=True)
        table.add_column("Type")
        table.add_column("Params", justify="right")

        # Only add headers for the optional size columns that the summary actually contains.
        # A set (instead of indexing ``zip(*summary_data)``) also copes with empty ``summary_data``.
        column_names = {column[0] for column in summary_data}
        for column_name in ("In sizes", "Out sizes"):
            if column_name in column_names:
                table.add_column(column_name, justify="right", style="white")

        # ``summary_data`` is column-major; transpose it into one tuple of cells per layer.
        for row in zip(*(column[1] for column in summary_data)):
            table.add_row(*row)

        console.print(table)

        trainable = get_human_readable_count(int(trainable_parameters))
        non_trainable = get_human_readable_count(int(total_parameters - trainable_parameters))
        total = get_human_readable_count(int(total_parameters))
        # The model size is a float in MB: print it with three decimals instead of truncating it to an
        # int and abbreviating it as if it were a parameter count (which turned e.g. 0.04 MB into "0").
        size_mb = f"{model_size:,.3f}"

        grid = Table.grid(expand=True)
        grid.add_column()
        grid.add_column()

        # Each value is left-aligned and padded to 10 characters.
        grid.add_row(f"[bold]Trainable params[/]: {trainable:<10}")
        grid.add_row(f"[bold]Non-trainable params[/]: {non_trainable:<10}")
        grid.add_row(f"[bold]Total params[/]: {total:<10}")
        grid.add_row(f"[bold]Total estimated model params size (MB)[/]: {size_mb:<10}")

        console.print(grid)
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.