Is there a fabric teardown method? (memory leak?)

Hello,

I am trying to instantiate multiple Fabric instances in a loop (one per iteration). However, after every iteration, my memory consumption goes up. Is there something like a fabric.teardown() method? I couldn't find anything in the documentation.
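Roughly, this is the pattern I have in mind (the fabric.teardown() call below is hypothetical, I could not find any such call in the API):

import lightning as L

for epoch in range(3):
    fabric = L.Fabric(devices=1, precision="bf16-true")
    fabric.launch()
    # ... set up model and optimizer, train for one epoch ...
    # Hypothetical call I am looking for: release everything this Fabric holds
    # fabric.teardown()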

I tried to make a self-contained example of my problem (derived from lit-gpt code but simplified to illustrate the issue). The parameters are chosen just to make the problem visible (a large enough network). devices is set to 1 for debugging, but the issue also shows up with more devices.

import shutil
import warnings
from pathlib import Path
from types import SimpleNamespace

import lightning as L
import torch
import torch.nn as nn
import wandb
from lightning.fabric.accelerators.mps import MPSAccelerator
from lightning.fabric.strategies import DeepSpeedStrategy


def main():
    config = SimpleNamespace(
        input_size=1024,
        output_size=1024,
        hidden_size=8192,
        num_layers=5,
        devices=1,
        ds_config=None,
        precision="bf16-true",
        entity="x",
        wandb_project="project",
        run_name="memory_test",
        checkpoint_path=None,
        learning_rate=1e-4,
        weight_decay=1e-5,
        max_steps=1000,
        batch_size=64,
        log_interval=1,
        out_dir=Path("./out"),
    )

    batch_size_per_device = config.batch_size / config.devices
    micro_batch_size = 1
    config.gradient_accumulation_steps = int(batch_size_per_device // micro_batch_size)

    for epoch in range(3):
        config.checkpoint_path = run_epoch(config)


def run_epoch(config):
    # Load Fabric
    fabric = L.Fabric(
        devices=config.devices,
        strategy=(
            DeepSpeedStrategy(config=config.ds_config) if config.devices > 1 else "auto"
        ),
        precision=config.precision,
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    if fabric.global_rank == 0:
        wandb.init(
            entity=config.entity,
            project=config.wandb_project,
            name=config.run_name,
            config=config,
        )

    with fabric.device:
        torch.set_default_tensor_type(torch.HalfTensor)
        model = LargeMODEL(config).bfloat16()
        torch.set_default_tensor_type(torch.FloatTensor)
        # Load checkpoint if this is not the first epoch
        if config.checkpoint_path is not None:
            checkpoint_path = config.checkpoint_path
            checkpoint = torch.load(checkpoint_path)
            model.load_state_dict(checkpoint, strict=False)

    # Setup model and optimizer in fabric
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay
    )
    model, optimizer = fabric.setup(model, optimizer)

    train_data = load_datasets(config)

    # Train the model
    train(
        config,
        fabric,
        model,
        optimizer,
        train_data,
    )

    # Save the final checkpoint at the end of training
    save_path = config.out_dir / "model_trained.pth"
    fabric.print(f"Saving weights to {str(save_path)!r}")
    print(save_path)
    save_model_checkpoint(fabric, model, save_path)
    return save_path


def train(config, fabric, model, optimizer, train_data):
    step_count = 0
    max_iters = int(config.max_steps * config.gradient_accumulation_steps)
    for iter_num in range(max_iters):
        input_ids, targets = get_batch(fabric, train_data, config.batch_size)
        with fabric.no_backward_sync(
            model, enabled=((iter_num + 1) % config.gradient_accumulation_steps != 0)
        ):
            logits = model(input_ids)
            loss = loss_fn(logits, targets)
            fabric.backward(loss / config.gradient_accumulation_steps)
        if (iter_num + 1) % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1
            fabric.call("on_train_batch_end", model=model)

        # Report performance to command line and wandb
        if step_count % config.log_interval == 0:
            if fabric.global_rank == 0:
                wandb.log(
                    {
                        "iter": iter_num,
                        "step": step_count,
                        "train/loss": loss,
                        "train/lr": optimizer.param_groups[0]["lr"],
                    }
                )


def save_model_checkpoint(fabric, model, file_path: Path):
    file_path = Path(file_path)
    # Ensure the directory exists
    file_path.parent.mkdir(parents=True, exist_ok=True)

    if isinstance(fabric.strategy, DeepSpeedStrategy):
        from deepspeed.utils.zero_to_fp32 import (
            get_fp32_state_dict_from_zero_checkpoint,
        )

        tmp_path = file_path.with_suffix(".tmp")
        fabric.save(tmp_path, {"model": model})
        fabric.barrier()
        if fabric.global_rank == 0:
            state_dict = get_fp32_state_dict_from_zero_checkpoint(tmp_path)
            torch.save(state_dict, file_path)
            shutil.rmtree(tmp_path)
    else:
        if fabric.global_rank == 0:
            state_dict = model.state_dict()
            torch.save(state_dict, file_path)
        fabric.barrier()

class LargeMODEL(nn.Module):
    def __init__(self, config):
        super().__init__()

        layers = []
        input_size = config.input_size
        for _ in range(config.num_layers):
            layers.append(nn.Linear(input_size, config.hidden_size))
            layers.append(nn.ReLU())
            input_size = config.hidden_size

        layers.append(nn.Linear(config.hidden_size, config.output_size))

        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def load_datasets(config):
    # For the sake of this demonstration, let's return some random data
    data = torch.randn(1000, config.input_size), torch.randn(1000, config.output_size)
    return data


def get_batch(fabric, data, batch_size):
    input_ids, targets = data
    ix = torch.randint(len(input_ids), (batch_size,))
    x, y = input_ids[ix], targets[ix]
    # Sample a random batch of batch_size examples from the dataset
    if isinstance(fabric.accelerator, MPSAccelerator):
        x, y = fabric.to_device((x, y))
    else:
        x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
    return x, y


def loss_fn(logits, targets):
    return ((logits - targets) ** 2).mean()


if __name__ == "__main__":
    # Uncomment this line if you see an error:
    # "Expected is_sm80 to be true, but got false"
    # torch.backends.cuda.enable_flash_sdp(False)
    torch.set_float32_matmul_precision("high")

    # from jsonargparse.cli import CLI
    warnings.filterwarnings(
        # false positive using deepspeed:
        # https://github.com/Lightning-AI/lightning/pull/17761#discussion_r1219705307
        "ignore",
        message="Remove `.no_backward_sync()` from your code",
    )
    main()

Here is the memory use for this example run (the three loop iterations that initialize Fabric coincide with the three memory spikes):
[Screenshot: system memory usage over the run, showing three spikes, one per Fabric instantiation]

Does anyone know how to fix this?

Thanks a lot for any help!

Hi @vkakerbeck

There is no teardown method for this, but there is a feature proposal for that: Provide teardown APIs in Fabric · Issue #14682 · Lightning-AI/lightning · GitHub

Instantiating the fabric object in each epoch should not be a problem, and it should also get garbage collected since it is only referenced within the run_epoch function. So I can't explain why this happens. My guess is that even with a teardown method, your memory increase would persist. I suggest opening a GitHub bug report issue for this if you don't mind.
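In the meantime, the closest thing to a manual teardown is to drop your own references and clear the CUDA caching allocator at the end of run_epoch. This is only a sketch of what you could try, and based on your observation it may not recover everything, but it will tell you whether the allocations are still reachable from Python:

import gc

import torch
import torch.distributed

# at the end of run_epoch, after saving the checkpoint:
del model, optimizer, fabric  # drop the local references
if torch.distributed.is_available() and torch.distributed.is_initialized():
    # only relevant for multi-device runs that initialized a process group
    torch.distributed.destroy_process_group()
gc.collect()  # force a garbage collection pass
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # return cached blocks to the driver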

Thanks for sharing!


Hi @awaelchli ,

thanks for your quick reply! I opened a github issue here: Memory Leak when instantiating Fabric multiple times · Issue #18356 · Lightning-AI/lightning · GitHub

Best wishes,
Viviane
