Hello,
I am trying to instantiate multiple Fabric instances in a loop (one per iteration), but after every iteration my memory consumption goes up. Is there a fabric.teardown()-like method? I couldn't find anything in the documentation.
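To make the question concrete, the pattern boils down to something like this (a minimal sketch; fabric.teardown() is the hypothetical method I am looking for, not an existing API):

import lightning as L

def run_epoch():
    fabric = L.Fabric(devices=1, precision="bf16-true")
    fabric.launch()
    # ... build the model, call fabric.setup(), train, save a checkpoint ...
    # At the end of the epoch I would like to call something like:
    # fabric.teardown()  # hypothetical; I could not find such a method

for epoch in range(3):
    run_epoch()  # memory consumption grows with every call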
I tried to make a self-contained example of my problem (derived from lit-gpt code, but simplified to illustrate the issue). The parameters are chosen only to make the problem visible (a large enough network); devices is set to 1 for debugging, but the problem also shows up with larger device counts.
import shutil
import warnings
from pathlib import Path
from types import SimpleNamespace
import lightning as L
import torch
import torch.nn as nn
import wandb
from lightning.fabric.accelerators.mps import MPSAccelerator
from lightning.fabric.strategies import DeepSpeedStrategy


def main():
    config = SimpleNamespace(
        input_size=1024,
        output_size=1024,
        hidden_size=8192,
        num_layers=5,
        devices=1,
        ds_config=None,
        precision="bf16-true",
        entity="x",
        wandb_project="project",
        run_name="memory_test",
        checkpoint_path=None,
        learning_rate=1e-4,
        weight_decay=1e-5,
        max_steps=1000,
        batch_size=64,
        log_interval=1,
        out_dir=Path("./out"),
    )
    batch_size_per_device = config.batch_size / config.devices
    micro_batch_size = 1
    config.gradient_accumulation_steps = int(batch_size_per_device // micro_batch_size)

    # A fresh Fabric instance is created in every call to run_epoch();
    # memory consumption grows after each iteration of this loop.
    for epoch in range(3):
        config.checkpoint_path = run_epoch(config)


def run_epoch(config):
    # Load Fabric
    fabric = L.Fabric(
        devices=config.devices,
        strategy=(
            DeepSpeedStrategy(config=config.ds_config) if config.devices > 1 else "auto"
        ),
        precision=config.precision,
    )
    fabric.launch()
    fabric.seed_everything(1337 + fabric.global_rank)

    if fabric.global_rank == 0:
        wandb.init(
            entity=config.entity,
            project=config.wandb_project,
            name=config.run_name,
            config=config,
        )

    with fabric.device:
        torch.set_default_tensor_type(torch.HalfTensor)
        model = LargeMODEL(config).bfloat16()
        torch.set_default_tensor_type(torch.FloatTensor)

    # Load checkpoint if this is not the first epoch
    if config.checkpoint_path is not None:
        checkpoint_path = config.checkpoint_path
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint, strict=False)

    # Setup model and optimizer in fabric
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay
    )
    model, optimizer = fabric.setup(model, optimizer)

    train_data = load_datasets(config)

    # Train the model
    train(
        config,
        fabric,
        model,
        optimizer,
        train_data,
    )

    # Save the final checkpoint at the end of training
    save_path = config.out_dir / "model_trained.pth"
    fabric.print(f"Saving weights to {str(save_path)!r}")
    print(save_path)
    save_model_checkpoint(fabric, model, save_path)
    return save_path


def train(config, fabric, model, optimizer, train_data):
    step_count = 0
    max_iters = int(config.max_steps * config.gradient_accumulation_steps)

    for iter_num in range(max_iters):
        input_ids, targets = get_batch(fabric, train_data, config.batch_size)

        # Skip gradient synchronization on all but the last micro-batch of each step
        with fabric.no_backward_sync(
            model, enabled=((iter_num + 1) % config.gradient_accumulation_steps != 0)
        ):
            logits = model(input_ids)
            loss = loss_fn(logits, targets)
            fabric.backward(loss / config.gradient_accumulation_steps)

        if (iter_num + 1) % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            step_count += 1
            fabric.call("on_train_batch_end", model=model)

            # Report performance to command line and wandb
            if step_count % config.log_interval == 0:
                if fabric.global_rank == 0:
                    wandb.log(
                        {
                            "iter": iter_num,
                            "step": step_count,
                            "train/loss": loss.item(),
                            "train/lr": optimizer.param_groups[0]["lr"],
                        }
                    )


def save_model_checkpoint(fabric, model, file_path: Path):
    file_path = Path(file_path)
    # Ensure the directory exists
    file_path.parent.mkdir(parents=True, exist_ok=True)

    if isinstance(fabric.strategy, DeepSpeedStrategy):
        from deepspeed.utils.zero_to_fp32 import (
            get_fp32_state_dict_from_zero_checkpoint,
        )

        tmp_path = file_path.with_suffix(".tmp")
        fabric.save(tmp_path, {"model": model})
        fabric.barrier()
        if fabric.global_rank == 0:
            state_dict = get_fp32_state_dict_from_zero_checkpoint(tmp_path)
            torch.save(state_dict, file_path)
            shutil.rmtree(tmp_path)
    else:
        if fabric.global_rank == 0:
            state_dict = model.state_dict()
            torch.save(state_dict, file_path)
        fabric.barrier()


class LargeMODEL(nn.Module):
    def __init__(self, config):
        super().__init__()
        layers = []
        input_size = config.input_size
        for _ in range(config.num_layers):
            layers.append(nn.Linear(input_size, config.hidden_size))
            layers.append(nn.ReLU())
            input_size = config.hidden_size
        layers.append(nn.Linear(config.hidden_size, config.output_size))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


def load_datasets(config):
    # For the sake of this demonstration, let's return some random data
    data = torch.randn(1000, config.input_size), torch.randn(1000, config.output_size)
    return data


def get_batch(fabric, data, batch_size):
    input_ids, targets = data
    # Sample batch_size random examples from the dataset
    ix = torch.randint(len(input_ids), (batch_size,))
    x, y = input_ids[ix], targets[ix]
    if isinstance(fabric.accelerator, MPSAccelerator):
        x, y = fabric.to_device((x, y))
    else:
        x, y = fabric.to_device((x.pin_memory(), y.pin_memory()))
    return x, y


def loss_fn(logits, targets):
    return ((logits - targets) ** 2).mean()


if __name__ == "__main__":
    # Uncomment this line if you see an error:
    # "Expected is_sm80 to be true, but got false"
    # torch.backends.cuda.enable_flash_sdp(False)
    torch.set_float32_matmul_precision("high")

    # from jsonargparse.cli import CLI
    warnings.filterwarnings(
        # false positive using deepspeed:
        # https://github.com/Lightning-AI/lightning/pull/17761#discussion_r1219705307
        "ignore",
        message="Remove `.no_backward_sync()` from your code",
    )
    main()

Here is the memory use for this example run (the three iterations of the loop, each initializing Fabric, coincide with the three memory spikes):
Does anyone know how to fix this?
Thanks a lot for any help!