
Accelerating LLaMA with Fabric: A Comprehensive Guide to Training and Fine-Tuning LLaMA

Takeaways

In this tutorial, we will learn how to train and fine-tune LLaMA (Large Language Model Meta AI). Lit-LLaMA, a from-scratch rewrite of LLaMA, can run inference on a consumer GPU with 8 GB of memory. We will also see how it uses Lightning Fabric to accelerate the PyTorch code.

What is LLaMA 🦙

LLaMA is a foundational large language model that has been released by Meta AI.

LLaMA comes in four size variants: 7B, 13B, 33B, and 65B parameters. The paper shows that training smaller foundation models on a large enough number of tokens is desirable, since the resulting models require less compute and fewer resources to run. The 33B and 65B parameter models were trained on 1.4 trillion tokens, while the LLaMA 7B model was trained on 1 trillion tokens.

Just a few weeks after the release of LLaMA, the open-source community embraced it, creating optimized versions and expanding its use cases. You can now fine-tune LLaMA with LoRA (which reduces the number of trainable parameters) and train a chatbot on the Stanford Alpaca dataset.

Lightning AI has also joined the trend by providing an open-source, from-scratch rewrite of LLaMA called Lit-LLaMA. The main highlight of Lit-LLaMA is that it is released under the Apache 2.0 license, which makes it easier to adopt for other deep learning projects that use similar permissive licenses and also enables commercial use. It has scripts for optimized training and fine-tuning with LoRA.

Lit-LLaMA: simple, optimized, and completely open-source 🔥

Lit-LLaMA is a from-scratch rewrite of LLaMA that uses Lightning Fabric to scale PyTorch code. It focuses on code readability and on optimizations that let it run on consumer GPUs. At the time of writing, you can run Lit-LLaMA on GPUs with 8 GB of memory 🤯.

Note: Currently you need to download the official Meta AI LLaMA pre-trained model weights for fine-tuning or running inference.

Lit-LLaMA supports training, fine-tuning, and inference. Let’s discuss each functionality in detail.

Training LLaMA

Note: We won’t go into too much detail about training LLaMA from scratch; instead, we focus on fine-tuning and inference, because the compute required for pre-training is out of reach for most of the community.

The repo comes with a simple and readable LLaMA model implementation and a training script accelerated by Fabric.

Large language models may not fit into a single GPU. Fully Sharded Data Parallelism (FSDP) is a technique that shards model parameters, gradients, and optimizer states across data parallel workers. Fabric provides a unified API that makes it easy to use FSDP.

To use FSDP (Fully-Sharded Data Parallel) with Fabric, create an FSDPStrategy object by specifying the auto-wrap policy and passing it as an argument to the Fabric class.

Fabric helps to automatically place the model and tensors on the correct devices, enabling distributed training, mixed precision, and the ability to select the number of devices to train on.



from functools import partial

import lightning as L
import torch
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

from lit_llama.model import Block, LLaMA, LLaMAConfig

# Hyperparameters such as block_size, learning_rate, weight_decay, beta1, and beta2,
# as well as load_datasets() and train(), are defined in the full training script.

def main():
    # ⚡️⚡️⚡️⚡️⚡️ Initialize the FSDP strategy ⚡️⚡️⚡️⚡️⚡️
    auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})
    strategy = FSDPStrategy(auto_wrap_policy=auto_wrap_policy, activation_checkpointing=Block)

    # ⚡️⚡️⚡️⚡️⚡️ Initialize Fabric ⚡️⚡️⚡️⚡️⚡️
    # settings for 4 GPUs with bf16 mixed precision and the FSDP distributed training strategy
    fabric = L.Fabric(accelerator="cuda", devices=4, precision="bf16-mixed", strategy=strategy)
    fabric.launch()

    # Load data
    train_data, val_data = load_datasets()

    # Load the model config
    config = LLaMAConfig.from_name("7B")
    config.block_size = block_size
    config.vocab_size = 100  # from prepare_shakespeare.py

    # ⚡️⚡️⚡️⚡️⚡️ Initialize the model ⚡️⚡️⚡️⚡️⚡️
    with fabric.device:
        model = LLaMA(config)

    # ⚡️⚡️⚡️⚡️⚡️ Set up the model and optimizer for distributed training ⚡️⚡️⚡️⚡️⚡️
    model = fabric.setup_module(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(beta1, beta2))
    optimizer = fabric.setup_optimizers(optimizer)

    train(fabric, model, optimizer, train_data, val_data)

You can also find the full training code here.
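
The train function called at the end of main lives in the same script. To show where Fabric fits in, here is a minimal sketch of a Fabric-accelerated training loop; get_batch and max_iters are simplified stand-ins for illustration, not the exact code from the repository:

def train(fabric, model, optimizer, train_data, val_data):
    """Minimal sketch of a Fabric-accelerated training loop (simplified for illustration)."""
    for iter_num in range(max_iters):  # max_iters: assumed hyperparameter
        input_ids, targets = get_batch(fabric, train_data)  # hypothetical batching helper
        logits = model(input_ids)
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1
        )
        optimizer.zero_grad()
        # ⚡️ fabric.backward() replaces loss.backward() and handles mixed precision and FSDP
        fabric.backward(loss)
        optimizer.step()

The only Fabric-specific change compared to a plain PyTorch loop is calling fabric.backward(loss) instead of loss.backward().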

Fine-tuning LLaMA

Within just weeks of LLaMA’s launch, the community began optimizing and building on top of it. Fine-tuning LLaMA on consumer GPUs is crucial to truly democratize LLMs. LLMs can be fine-tuned to build chatbots or to specialize in particular tasks or domains, such as summarizing legal or financial data.

Lit-LLaMA includes a fine-tuning script that utilizes LoRA (Low-Rank Adaptation of Large Language Models). LoRA freezes the pre-trained model weights and injects trainable rank decomposition matrices into each layer of the Transformer architecture. This significantly reduces the number of trainable parameters for downstream tasks.
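
Conceptually, LoRA keeps a pre-trained weight matrix W frozen and learns a low-rank update, so the effective weight becomes W + (alpha / r) · B A, where A and B are small matrices of rank r. The snippet below is a minimal, self-contained illustration of the idea, not Lit-LLaMA’s actual implementation (which lives in lit_llama/lora.py):

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Toy example: a frozen linear layer plus a trainable low-rank update."""

    def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: int = 16):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        for p in self.linear.parameters():
            p.requires_grad = False  # freeze the pre-trained weights
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)  # trainable, rank r
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))        # trainable, starts at zero
        self.scaling = alpha / r

    def forward(self, x):
        # frozen path + scaled low-rank trainable path
        return self.linear(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling

Because only lora_A and lora_B receive gradients, the number of trainable parameters per layer drops from in_features × out_features to r × (in_features + out_features).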

To initialize LLaMA with LoRA layers, we use the lora context manager from lit_llama.lora:



from lit_llama.lora import lora, mark_only_lora_as_trainable
from lit_llama.model import LLaMA, LLaMAConfig

# LoRA hyperparameters (example values; the fine-tuning script defines its own)
lora_r = 8
lora_alpha = 16
lora_dropout = 0.05

# initialize configs
config = LLaMAConfig.from_name("7B")
config.block_size = block_size  # block_size is defined in the full fine-tuning script

# initialize the model with LoRA layers
with lora(r=lora_r, alpha=lora_alpha, dropout=lora_dropout, enabled=True):
    model = LLaMA(config)

# mark only the LoRA-injected layers as trainable
mark_only_lora_as_trainable(model)

Note that lora is a Python context manager that temporarily swaps the model’s CausalSelfAttention class for a LoRA-injected version containing the trainable low-rank parameters:



from contextlib import contextmanager

@contextmanager
def lora(r, alpha, dropout, enabled: bool = True):
    """A context manager under which you can instantiate the model with LoRA."""
    if not enabled:
        yield
        return

    LoRACausalSelfAttention.lora_config = LoRAConfig(r=r, alpha=alpha, dropout=dropout)
    # swap the attention class so that models built inside the context use LoRA
    causal_self_attention = llama.CausalSelfAttention
    llama.CausalSelfAttention = LoRACausalSelfAttention
    yield
    # restore the original attention class on exit
    llama.CausalSelfAttention = causal_self_attention
    LoRACausalSelfAttention.lora_config = None

We are now ready to fine-tune LLaMA. You can fine-tune Lit-LLaMA on the Alpaca dataset using LoRA and quantization on a consumer GPU. Lit-LLaMA comes with a simple script for downloading and preparing the Alpaca dataset, which you can find here.

Note: You can convert the official Meta AI LLaMA weights to Lit-LLaMA format using the instructions here.

Follow these two simple steps to instruction-tune LLaMA:

  1. Download data and generate instruction tuning dataset: python scripts/prepare_alpaca.py
  2. Run the fine-tuning script: python finetune_lora.py

Find the full fine-tuning code here.

Generating text from a trained model

To generate text predictions, you need trained model weights: either the official Meta AI weights or a model you have fine-tuned yourself. Lit-LLaMA includes a text-generation script that, with quantization, can run on a GPU with 8 GB of memory. To generate text, run the following command in the terminal:

python generate.py --quantize true --prompt "Here's what people think about pineapple pizza: "

Conclusion

Lit-LLaMA promotes open and collective science by releasing its source code under the Apache 2.0 license. It extends the original Meta AI code with training, an optimized fine-tuning script, and the ability to run inference on a consumer GPU (using no more than 8 GB of memory with quantization). Lit-LLaMA has already crossed 2K GitHub stars 💫, and it will be interesting to see what the community builds on top of it.

Go to Repo