Thunder bindings for Liger operators

In this notebook we explore Thunder bindings for Liger operators.

It is based on Episode 10 of the Thunder Sessions podcast.

Let’s import things.

[1]:
from collections.abc import Sequence
import math

import torch
from torch.testing import assert_close
import litgpt
import thunder
from thunder.core.proxies import TensorProxy, AnyProxy
from thunder.core.transforms import get_grad, put_grads
from thunder.torch import TensorLike
import thunder.extend

import liger_kernel.ops.rms_norm
import liger_kernel.ops.rope
import liger_kernel.ops.swiglu
import liger_kernel.ops.geglu  # TODO
import liger_kernel.ops.cross_entropy  # TODO
import liger_kernel.ops.fused_linear_cross_entropy

device = torch.device("cuda")

We define and register an OperatorExecutor named "liger". The Liger operators and their implementations below get registered on it, and we pass it to thunder.jit through the executors argument when compiling the models.

[2]:
liger_ex = thunder.extend.OperatorExecutor("liger", version="0.1")
thunder.extend.register_executor(liger_ex)
[2]:
thunder.extend.OperatorExecutor('liger')

RMS Norm

The first thing to fuse is RMS Norm.

Once LitGPT uses PyTorch's torch.nn.functional.rms_norm (see the next cell), Liger's implementation is a drop-in replacement. We define operators for the Liger forward and backward kernels, then a gradient rule and an execution rule.

We register these as the implementation of an rms_norm operator to which we divert the PyTorch function.
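
To make the pattern concrete before the real thing, here is a minimal sketch with a made-up toy_scale operator (the names and the operator are hypothetical and not used elsewhere in this notebook): register_operator pairs a meta function, which only describes the outputs via proxies, with the actual implementation, and register_implementation attaches the execution and gradient rules built with get_grad/put_grads.

def toy_scale_meta(x, alpha):
    # meta functions run at trace time and only describe output shapes/dtypes via proxies
    return TensorProxy(like=x)


def toy_scale_fn(x, alpha):
    return x * alpha


toy_scale = liger_ex.register_operator("toy_scale", meta=toy_scale_meta, fn=toy_scale_fn)


def toy_scale_grad_transform(x, alpha):
    y = toy_scale(x, alpha)
    dy = get_grad(y)  # gradient flowing into the output
    put_grads((x,), (dy * alpha,))  # d(x * alpha)/dx = alpha
    return y


liger_ex.register_implementation(toy_scale, execution_transform=toy_scale, grad_transform=toy_scale_grad_transform)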

[3]:
# A tiny detail here is that PyTorch gained a `rms_norm` function somewhat
# recently and we need to tell LitGPT to use it.


def RMSNorm_forward(self, x):
    return torch.nn.functional.rms_norm(x, self.weight.shape, self.weight, self.eps)


litgpt.model.RMSNorm.forward = RMSNorm_forward
[4]:
import functools

# product of a sequence of ints (1 if empty)
prod = lambda *args: functools.reduce(lambda x, y: x * y, args, 1)
[5]:
# ******************************* RMS NORM *******************************
import functools


def liger_rms_norm_forward_meta(X, W, eps, offset, casting_mode):
    *n_rows, n_cols = X.shape
    n_rows = prod(*n_rows)
    # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
    rstd_dtype = (
        thunder.dtypes.float32
        if casting_mode
        in (liger_kernel.ops.rms_norm._CASTING_MODE_LLAMA.value, liger_kernel.ops.rms_norm._CASTING_MODE_GEMMA.value)
        else X.dtype
    )
    Y = TensorProxy(like=X)
    RSTD = TensorProxy(like=X, shape=(n_rows,), dtype=rstd_dtype)
    BLOCK_SIZE, num_warps = liger_kernel.ops.rms_norm.calculate_settings(n_cols)
    # Liger's rms_norm_forward also returns the flattened 2-D X that the backward kernel
    # consumes, together with the kernel launch settings
    return Y, TensorProxy(like=X, shape=(n_rows, n_cols)), RSTD, BLOCK_SIZE, num_warps, casting_mode


liger_rms_norm_forward = liger_ex.register_operator(
    "liger_rms_norm_forward", meta=liger_rms_norm_forward_meta, fn=liger_kernel.ops.rms_norm.rms_norm_forward
)


def liger_rms_norm_backward_meta(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
    # dX has the same flattened 2-D shape as the X saved in the forward pass
    return TensorProxy(like=X), TensorProxy(like=W)


liger_rms_norm_backward = liger_ex.register_operator(
    "liger_rms_norm_backward", meta=liger_rms_norm_backward_meta, fn=liger_kernel.ops.rms_norm.rms_norm_backward
)


def rms_norm_meta(x, shape, w, eps):
    return thunder.TensorProxy(like=x)


rms_norm = liger_ex.register_operator(
    "rms_norm", meta=rms_norm_meta, fn=torch.nn.functional.rms_norm, replaces=torch.nn.functional.rms_norm
)


def rms_norm_grad_transform(x, shape, weight, eps):
    Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = liger_rms_norm_forward(
        x, weight, eps, offset=0.0, casting_mode="llama"
    )
    dY = get_grad(Y)
    dX, dW = liger_rms_norm_backward(
        dY, X, weight, RSTD, offset=0.0, casting_mode="llama", BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
    )
    dX = dX.view(*x.shape)  # reshape the flattened 2-D gradient back to the input's shape
    put_grads((x, weight), (dX, dW))
    return Y


def rms_norm_execution_transform(x, shape, weight, eps):  # same signature as the rms_norm symbol
    Y, *_ = liger_rms_norm_forward(x, weight, eps, offset=0.0, casting_mode="llama")
    return Y


liger_ex.register_implementation(
    rms_norm, execution_transform=rms_norm_execution_transform, grad_transform=rms_norm_grad_transform
)

Testing RMS Norm

Let’s test.

[6]:
hidden_size = 64

example_input = torch.randn(32, 10, hidden_size, device=device, requires_grad=True)

with device:
    model = litgpt.model.RMSNorm(hidden_size)
thunder_model = thunder.jit(model, executors=[liger_ex])
ref = model(example_input.clone())
res = thunder_model(example_input.clone())
go = torch.randn_like(ref)
grad_ref, grad_ref_weight = torch.autograd.grad(ref, (example_input, model.weight), go)
grad_res, grad_res_weight = torch.autograd.grad(res, (example_input, model.weight), go)


assert liger_rms_norm_forward in {bsym.sym for bsym in thunder.last_traces(thunder_model)[-1].bound_symbols}
assert liger_rms_norm_backward in {bsym.sym for bsym in thunder.last_backward_traces(thunder_model)[-1].bound_symbols}

assert_close(ref, res)
assert_close(grad_ref, grad_res)
assert_close(grad_ref_weight, grad_res_weight)
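
To see the fused calls in context, you can print the final forward and backward execution traces, for example:

print(thunder.last_traces(thunder_model)[-1])
print(thunder.last_backward_traces(thunder_model)[-1])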
[ ]:

RoPE

Next is the RoPE implementation. Liger applies RoPE to the query and the key in a single kernel, whereas LitGPT uses two separate apply_rope calls. So we define forward and backward operators and a symbol that captures the LitGPT version, plus a small trace transform that fuses the two apply_rope calls into one liger_rope.

[7]:
def liger_rope_forward_meta(q, k, cos, sin):
    return TensorProxy(like=q), TensorProxy(like=k), cos, sin


liger_rope_forward = liger_ex.register_operator(
    "liger_rope_forward",
    meta=liger_rope_forward_meta,
    fn=liger_kernel.ops.rope.rope_forward,
)


def liger_rope_backward_meta(dq, dk, cos, sin):
    return TensorProxy(like=dq), TensorProxy(like=dk)


liger_rope_backward = liger_ex.register_operator(
    "liger_rope_backward",
    meta=liger_rope_backward_meta,
    fn=liger_kernel.ops.rope.rope_backward,
)


def liger_rope_grad_transform(q, k, cos, sin):
    q_out, k_out, _, _ = liger_rope_forward(q, k, cos, sin)
    q_out_grad = get_grad(q_out)
    k_out_grad = get_grad(k_out)
    dq, dk = liger_rope_backward(q_out_grad, k_out_grad, cos, sin)
    put_grads((q, k), (dq, dk))
    return q_out, k_out


def liger_rope_execution_transform(q, k, cos, sin):
    q_out, k_out, _, _ = liger_rope_forward(q, k, cos, sin)
    return q_out, k_out


def liger_rope_impl(q, k, cos, sin):
    qr, kr, _, _ = liger_rope_forward(q, k, cos, sin)
    return qr, kr


liger_rope = liger_ex.register_operator("liger_rope", fn=liger_rope_impl, like=liger_rope_impl)

liger_ex.register_implementation(
    liger_rope,
    execution_transform=liger_rope_execution_transform,
    grad_transform=liger_rope_grad_transform,
)


def litgpt_apply_rope_meta(x, cos, sin):
    return TensorProxy(like=x)


litgpt_apply_rope = liger_ex.register_operator(
    "litgpt_apply_rope", fn=litgpt.model.apply_rope, meta=litgpt_apply_rope_meta, replaces=litgpt.model.apply_rope
)


class MergeRopeTransform(thunder.core.transform_common.Transform):
    def transform_traces_pre_prologue(self, prologue_trace, compute_trace, epilogue_trace, **kwargs):
        new_compute_trace = thunder.core.trace.from_trace(compute_trace)
        bound_symbols = compute_trace.bound_symbols[:]
        while bound_symbols:
            bsym = bound_symbols.pop(0)
            if bsym.sym == litgpt_apply_rope:
                # find the matching second apply_rope call (LitGPT rotates q and k separately)
                for i, bsym2 in enumerate(bound_symbols):
                    assert not any(o is bsym.output for o in bsym2.flat_outs)
                    if bsym2.sym == litgpt_apply_rope:
                        break
                bsym2 = bound_symbols.pop(i)
                assert bsym2.sym == litgpt_apply_rope

                # replace the pair with a single liger_rope call that produces both rotated tensors
                output = (bsym.output, bsym2.output)
                args = (bsym.args[0], bsym2.args[0], *bsym.args[1:])

                new_compute_trace.bound_symbols.append(bsym.from_bsym(args=args, output=output, sym=liger_rope))
            else:
                new_compute_trace.bound_symbols.append(bsym.from_bsym())
        new_compute_trace.set_provenance(thunder.core.trace.TraceProvenance(self.__class__))
        return prologue_trace, new_compute_trace, epilogue_trace

Testing RoPE

We test with a scaled-down Llama 3.2 1B (a single transformer layer).

[8]:
cfg = litgpt.Config.from_name("Llama-3.2-1B", n_layer=1)
with device:
    m = litgpt.GPT(cfg)
    m.max_seq_length = 1024
    m.set_kv_cache(1)
    inp = torch.arange(1, 6, dtype=torch.int64)[None]
    inp_pos = torch.arange(1, 6, dtype=torch.int64)


jm = thunder.jit(m, executors=(liger_ex,), transforms=(MergeRopeTransform(),))
res = jm(inp, inp_pos)

go = torch.randn_like(res)
(grad_res,) = torch.autograd.grad(res, jm.get_parameter("transformer.wte.weight"), go)
ref = m(inp, inp_pos)
(grad_ref,) = torch.autograd.grad(ref, m.get_parameter("transformer.wte.weight"), go)

assert_close(res, ref)
assert_close(grad_res, grad_ref)

assert any(bsym.sym is liger_rope_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_rope_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_rms_norm_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_rms_norm_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)

SwiGLU
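
LitGPT's LLaMA-style MLP computes SwiGLU as the SiLU of one up-projection multiplied by the other. We register Liger's fused kernel for this, and FuseSwigLUTransform below rewrites a silu whose only consumer is a multiplication into a single liger_swiglu_forward call. For orientation, an unfused sketch of what the fused kernel computes (reference only, not used by the notebook):

def swiglu_reference(a, b):
    # SwiGLU(a, b) = silu(a) * b; Liger computes this in one Triton kernel
    return torch.nn.functional.silu(a) * b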

[9]:
def liger_swiglu_forward_meta(a, b):
    return TensorProxy(like=a)


def liger_swiglu_forward_impl(a, b):
    # Liger's swiglu_forward returns (a, b, silu(a) * b); we only need the result
    _, _, res = liger_kernel.ops.swiglu.swiglu_forward(a, b)
    return res


liger_swiglu_forward = liger_ex.register_operator(
    "liger_swiglu_forward",
    meta=liger_swiglu_forward_meta,
    fn=liger_swiglu_forward_impl,
)


def liger_swiglu_backward_meta(a, b, grad_res):
    return TensorProxy(like=a), TensorProxy(like=b)


liger_swiglu_backward = liger_ex.register_operator(
    "liger_swiglu_backward",
    meta=liger_swiglu_backward_meta,
    fn=liger_kernel.ops.swiglu.swiglu_backward,
)


def liger_swiglu_gradient_transform(a, b):
    res = liger_swiglu_forward(a, b)
    grad_res = get_grad(res)
    grad_a, grad_b = liger_swiglu_backward(a, b, grad_res)
    put_grads((a, b), (grad_a, grad_b))
    return res


liger_ex.register_implementation(
    liger_swiglu_forward, grad_transform=liger_swiglu_gradient_transform, execution_transform=liger_swiglu_forward
)


class FuseSwigLUTransform(thunder.core.transform_common.Transform):
    def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
        _, consumers = thunder.core.utils.producers_and_consumers(computation_trace)
        new_computation_trace = thunder.core.trace.from_trace(computation_trace)
        bsyms_to_skip = set()
        for b in computation_trace.bound_symbols:
            if b in bsyms_to_skip:
                continue
            new_bsym = b
            if b.sym == thunder.torch.silu:
                c = consumers[b.output]
                # fuse silu(a) * other into one liger_swiglu_forward when the silu result
                # is consumed exactly once, by a multiplication
                if len(c) == 1 and c[0].sym == thunder.torch.mul:
                    (mul,) = c
                    mul_l, mul_r = mul.args
                    if mul_l is b.output:
                        other = mul_r
                    else:
                        other = mul_l
                    new_bsym = b.from_bsym(
                        sym=liger_swiglu_forward, output=mul.output, args=(b.args[0], other), subsymbols=[]
                    )
                    bsyms_to_skip.add(mul)
            new_computation_trace.bound_symbols.append(new_bsym)
        new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("constructed by FuseSwigLU"))
        return prologue_trace, new_computation_trace, epilogue_trace
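
Like MergeRopeTransform, this is a trace-to-trace transform; we pass it to thunder.jit through the transforms argument and exercise it together with the other fusions in the combined example further down.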
[ ]:

Fused Linear and Cross Entropy
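
The final projection followed by cross entropy is where the logits tensor gets large. Liger's fused linear + cross-entropy kernel computes the loss, and the gradients for the input, weight, and bias, in chunks so the full logits never need to be materialized at once. We register forward and backward operators for it, and FuseLinearCrossEntropyTransform below rewrites a linear whose only consumer is cross_entropy into the fused call. For orientation, an unfused sketch of the pattern being replaced (reference only, not used by the notebook):

def linear_cross_entropy_reference(x, weight, target, bias=None):
    # the unfused linear -> cross_entropy pattern that the transform matches in the trace
    logits = torch.nn.functional.linear(x, weight, bias)
    return torch.nn.functional.cross_entropy(logits, target)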

[10]:
def liger_fused_linear_cross_entropy_forward_meta(
    _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0, reduction="mean"
):
    # trace thunder's own linear + cross_entropy just to obtain a loss proxy with the right shape and dtype
    logits = thunder.torch.linear(_input, weight, bias)
    loss = thunder.torch.cross_entropy(
        logits, target, ignore_index=ignore_index, label_smoothing=label_smoothing, reduction=reduction
    )
    grad_input = TensorProxy(like=_input)
    grad_weight = TensorProxy(like=weight)
    grad_bias = None if bias is None else TensorProxy(like=bias)
    return loss, grad_input, grad_weight, grad_bias


liger_fused_linear_cross_entropy_forward = liger_ex.register_operator(
    "liger_fused_linear_cross_entropy_forward",
    fn=liger_kernel.ops.fused_linear_cross_entropy.fused_linear_cross_entropy_forward,
    like=liger_fused_linear_cross_entropy_forward_meta,
)


def liger_fused_linear_cross_entropy_backward_meta(grad_output, grad_input, grad_weight, grad_bias):
    return (
        TensorProxy(like=grad_input),
        TensorProxy(like=grad_weight),
        (TensorProxy(like=grad_bias) if grad_bias is not None else None),
    )


liger_fused_linear_cross_entropy_backward = liger_ex.register_operator(
    "liger_fused_linear_cross_entropy_backward",
    fn=liger_kernel.ops.fused_linear_cross_entropy.fused_linear_cross_entropy_backward,
    meta=liger_fused_linear_cross_entropy_backward_meta,
)


def liger_fused_linear_cross_entropy_grad_transform(
    _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0, reduction="mean"
):
    loss, grad_input_1, grad_weight_1, grad_bias_1 = liger_fused_linear_cross_entropy_forward(
        _input,
        weight,
        target,
        bias=bias,
        ignore_index=ignore_index,
        label_smoothing=label_smoothing,
        reduction=reduction,
    )
    grad_loss = get_grad(loss)
    grad_input, grad_weight, grad_bias = liger_fused_linear_cross_entropy_backward(
        grad_loss, grad_input_1, grad_weight_1, grad_bias_1
    )
    # target is an integer tensor and gets no gradient; bias only gets one if it is present
    put_grads((_input, weight), (grad_input, grad_weight))
    if bias is not None:
        put_grads((bias,), (grad_bias,))
    return loss


liger_ex.register_implementation(
    liger_fused_linear_cross_entropy_forward,
    grad_transform=liger_fused_linear_cross_entropy_grad_transform,
    execution_transform=liger_fused_linear_cross_entropy_forward,
)


class FuseLinearCrossEntropyTransform(thunder.core.transform_common.Transform):
    def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
        _, consumers = thunder.core.utils.producers_and_consumers(computation_trace)
        new_computation_trace = thunder.core.trace.from_trace(computation_trace)
        bsyms_to_skip = set()
        for b in computation_trace.bound_symbols:
            if b in bsyms_to_skip:
                continue
            new_bsym = b
            if b.sym == thunder.torch.linear:
                c = consumers[b.output]
                # fuse a linear whose only consumer is cross_entropy into Liger's fused op
                if len(c) == 1 and c[0].sym == thunder.torch.cross_entropy:
                    (ce,) = c
                    assert not ce.kwargs
                    assert not b.kwargs
                    assert ce.args[0] is b.output
                    inp, weight, bias = b.args
                    _, targets, ce_weight, size_average, ignore_index, reduce, reduction, label_smoothing = ce.args
                    assert ce_weight is None
                    assert size_average is None
                    assert reduce is None
                    new_bsym = b.from_bsym(
                        sym=liger_fused_linear_cross_entropy_forward,
                        output=ce.output,
                        args=(inp, weight, targets, bias, ignore_index, label_smoothing, reduction),
                        subsymbols=[],
                    )
                    bsyms_to_skip.add(ce)
            new_computation_trace.bound_symbols.append(new_bsym)
        new_computation_trace.set_provenance(
            thunder.core.trace.TraceProvenance("constructed by FuseLinearCrossEntropy")
        )
        return prologue_trace, new_computation_trace, epilogue_trace
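
Testing the fusions together

To exercise all the fusions at once, we subclass LitGPT's GPT with a forward that returns the cross-entropy loss of a single next-token prediction (taken from the second-to-last position), so the trace ends in exactly the linear-followed-by-cross_entropy pattern that FuseLinearCrossEntropyTransform rewrites. The small apply_eye helper ORs an identity matrix into the attention mask so every position can attend to itself; registering it as an operator makes it appear as a single symbol in the trace, and its gradient rule just re-emits the call because the boolean mask carries no gradient.
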
[11]:
def apply_eye_meta(x):
    return thunder.TensorProxy(like=x)


def apply_eye(mask):
    # let every position attend to itself by OR-ing the identity matrix into the mask
    mask = mask | torch.eye(mask.shape[-1], dtype=torch.bool, device=mask.device)[None, None]
    return mask


t_apply_eye = liger_ex.register_operator("t_apply_eye", fn=apply_eye, meta=apply_eye_meta, replaces=apply_eye)


def apply_eye_grad_transform(x):
    # the mask is boolean, so there is no gradient to propagate; just re-emit the call
    return t_apply_eye(x)


liger_ex.register_implementation(
    t_apply_eye, execution_transform=apply_eye_grad_transform, grad_transform=apply_eye_grad_transform
)


class GPTForFineTuningLastToken(litgpt.model.GPT):
    def forward(self, idx: torch.Tensor, *, mask: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        mask = mask.bool()
        T = idx.size(1)
        if self.max_seq_length < T:
            raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")

        # causal mask restricted to the non-padded positions
        attn_mask = (
            litgpt.model.build_mask_cache(mask.shape[-1], mask.device).expand(idx.shape[0], -1, -1, -1)
            * mask[:, None, None, :]
        )
        attn_mask = apply_eye(attn_mask)

        cos = self.cos[:T]
        sin = self.sin[:T]
        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        if self.config.scale_embeddings:
            x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype)

        for block in self.transformer.h:
            x = block(x, cos, sin, attn_mask, None)

        # second to last prediction is the output
        x = x[:, -2]
        x = self.transformer.ln_f(x)
        x = self.lm_head(x)  # (b, vocab_size)
        if self.config.final_logit_softcapping is not None:
            x = torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping
        loss = torch.nn.functional.cross_entropy(x, labels)
        return loss


cfg = litgpt.Config.from_name("Llama-3.2-1B", n_layer=1)
with device:
    m = GPTForFineTuningLastToken(cfg)
    m.max_seq_length = 1024
    inp = torch.ones(4, 32, dtype=torch.int64)
    mask = torch.ones(4, 32, dtype=torch.int64)
    labels = torch.ones(4, dtype=torch.int64)


jm = thunder.jit(
    m,
    executors=(liger_ex,),
    transforms=(
        MergeRopeTransform(),
        FuseSwigLUTransform(),
        FuseLinearCrossEntropyTransform(),
    ),
)
res = jm(inp, mask=mask, labels=labels)
ref = m(inp, mask=mask, labels=labels)

go = torch.randn_like(res)
(grad_res,) = torch.autograd.grad(res, jm.get_parameter("transformer.wte.weight"), go)
(grad_ref,) = torch.autograd.grad(ref, m.get_parameter("transformer.wte.weight"), go)

assert_close(res, ref)
assert_close(grad_res, grad_ref)

assert any(bsym.sym is liger_rope_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_rope_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_rms_norm_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_rms_norm_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_swiglu_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_swiglu_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_fused_linear_cross_entropy_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)
assert any(
    bsym.sym is liger_fused_linear_cross_entropy_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols
)

End to end example

This end-to-end fine-tuning example is adapted from a Liger-Kernel example.

The code below is Copyright 2024 LinkedIn Corporation (BSD 2-CLAUSE LICENSE).

[12]:
if False:  # this example has additional dependencies, so we skip it in the CI
    import argparse
    import math
    import os
    from dataclasses import _MISSING_TYPE, dataclass
    import litgpt

    import datasets
    import lightning.pytorch as pl
    import torch
    import transformers
    from lightning.pytorch.strategies import DeepSpeedStrategy, FSDPStrategy
    from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
    from torch.utils.data import DataLoader
    from trl import DataCollatorForCompletionOnlyLM
    import warnings

    warnings.simplefilter(action="ignore", category=FutureWarning)


    _RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"}
    QUESTION = "<Question>"
    CHOICES = "<Choices>"


    @dataclass
    class Args:
        model: str = "meta-llama/Llama-3.2-1B-Instruct"
        data: str = "cais/mmlu"
        output_dir: str = "mmlu_finetuning"
        max_length: int = 2048
        # for the Llama 3 8B model, DeepSpeed will OOM with batch size 16 on 8xA100 80G and with 8 on 8xA100 40G
        batch_size: int = 4
        lr: float = 6e-6
        weight_decay: float = 0.05
        warmup_ratio: float = 0.1
        seed: int = 42
        strategy: str = "auto"
        num_gpu: int = 1


    def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0):
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                # Linear warmup
                return float(current_step) / float(max(1, warmup_steps))
            else:
                # Cosine annealing
                progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
                return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress)))

        return lr_lambda


    def parse_args() -> Args:
        parser = argparse.ArgumentParser()
        for k, v in Args.__dataclass_fields__.items():
            parser.add_argument(f"--{k}", type=v.type, default=v.default)
        parsed = parser.parse_args([])
        return Args(**{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)})


    class LanguageModel(pl.LightningModule):
        def __init__(self, args: Args, tokenizer):
            super().__init__()
            self.args = args
            self.tokenizer = tokenizer
            self.model = None

        def configure_model(self):
            # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization
            if self.model is not None:
                return
            self.model = GPTForFineTuningLastToken.from_name(self.args.model.rsplit("/", 1)[-1]).to(torch.bfloat16)
            self.model.load_state_dict(litgpt.utils.lazy_load(f"checkpoints/{self.args.model}/lit_model.pth"))
            self.model = thunder.jit(
                self.model,
                executors=(liger_ex, *thunder.get_default_executors()),
                transforms=(MergeRopeTransform(), FuseSwigLUTransform(), FuseLinearCrossEntropyTransform()),
            )

        def forward(self, input_ids, attention_mask, labels=None, **kwargs):
            return self.model(idx=input_ids, mask=attention_mask, labels=labels, **kwargs)

        def training_step(self, batch):
            outputs = self.model(
                idx=batch["input_ids"],
                mask=batch["attention_mask"],
                labels=batch["labels"][:, -1],
            )
            loss = outputs
            self.log_dict(
                {"train_loss": loss},
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                rank_zero_only=True,
                sync_dist=False,
            )
            return loss

        def validation_step(self, batch):
            outputs = self.model(
                idx=batch["input_ids"],
                mask=batch["attention_mask"],
                labels=batch["labels"][:, -1],
            )
            loss = outputs
            self.log_dict(
                {"val_loss": loss},
                on_step=True,
                on_epoch=True,
                prog_bar=True,
                logger=True,
                rank_zero_only=True,
                sync_dist=True,
            )
            return loss

        def configure_optimizers(self):
            optimizer = torch.optim.AdamW(
                self.parameters(),
                lr=self.args.lr,
                weight_decay=self.args.weight_decay,
                fused=True,
            )
            lr_lambda = warmup_cosine_schedule(
                warmup_steps=self.trainer.estimated_stepping_batches * self.args.warmup_ratio,
                total_steps=self.trainer.estimated_stepping_batches,
                min_lr=0,
            )
            lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
            }


    class DataModule(pl.LightningDataModule):
        def __init__(self, tokenizer, args: Args):
            super().__init__()
            self.train_dataset = None
            self.args = args
            self.tokenizer = tokenizer
            self.response_template_str = " <Answer>"
            response_prompt = tokenizer.encode(f"{self.response_template_str}", add_special_tokens=False)
            self.collator = DataCollatorForCompletionOnlyLM(
                tokenizer=tokenizer,
                response_template=response_prompt,
                pad_to_multiple_of=16,
            )

        def formatting_func(self, example):
            output_texts = []
            for i in range(len(example["question"])):
                choices = ""
                for j in range(len(example["choices"][i])):
                    choices += f"{j+1}. {example['choices'][i][j]}; "
                s = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. "
                s += f"{QUESTION}{example['question'][i]} "
                s += f"{CHOICES}{choices} "
                s += f"{self.response_template_str}{example['answer'][i]}"
                output_texts.append(s)
            return output_texts

        def tokenize(self, example):
            outputs = self.tokenizer(
                self.formatting_func(example),
                truncation=True,
                padding=False,
                max_length=self.args.max_length,
            )
            return {
                "input_ids": outputs["input_ids"],
                "attention_mask": outputs["attention_mask"],
            }

        def setup(self, stage) -> None:
            if self.train_dataset is not None:
                return
            dataset = datasets.load_dataset(self.args.data, "auxiliary_train")
            flattened_data = [
                {
                    "answer": x["train"]["answer"],
                    "choices": x["train"]["choices"],
                    "question": x["train"]["question"],
                    "subject": x["train"]["subject"],
                }
                for x in dataset["train"]
            ][:32]
            dataset = datasets.Dataset.from_list(flattened_data)
            dataset = dataset.train_test_split(test_size=4, seed=self.args.seed)
            train_dataset, val_dataset = dataset["train"], dataset["test"]
            self.train_dataset = train_dataset.map(
                self.tokenize,
                remove_columns=list(set(train_dataset.column_names) - _RETAIN_COLUMNS),
                batched=True,
                batch_size=1,
                num_proc=4,
            )
            self.val_dataset = val_dataset.map(
                self.tokenize,
                remove_columns=list(set(val_dataset.column_names) - _RETAIN_COLUMNS),
                batched=True,
                batch_size=1,
                num_proc=4,
            )

        def train_dataloader(self):
            return DataLoader(
                self.train_dataset,
                batch_size=self.args.batch_size,
                collate_fn=self.collator,
            )

        def val_dataloader(self):
            return DataLoader(
                self.val_dataset,
                batch_size=self.args.batch_size,
                collate_fn=self.collator,
            )


    args = parse_args()
    pl.seed_everything(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    if args.strategy == "fsdp":
        # assumption: wrap and checkpoint at the granularity of LitGPT's transformer Block
        layers = {litgpt.model.Block}
        strategy = FSDPStrategy(
            auto_wrap_policy=layers,
            sharding_strategy="FULL_SHARD",
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            sync_module_states=True,
            activation_checkpointing_policy=layers,
            mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
            forward_prefetch=True,
        )
        precision = None
    elif args.strategy == "deepspeed":
        strategy = DeepSpeedStrategy(stage=3)
        precision = "bf16-mixed"
    elif args.strategy == "ddp":
        strategy = "ddp"
        precision = "bf16-true"
    else:
        strategy = "auto"
        precision = "bf16-true"

    # This only works if you have a snapshot to work from.
    trainer = pl.Trainer(
        accelerator="cuda",
        strategy=strategy,
        devices=torch.cuda.device_count() if args.num_gpu is None else args.num_gpu,
        default_root_dir=args.output_dir,
        log_every_n_steps=1,
        max_epochs=1,
        precision=precision,
    )

    tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, padding_side="left", truncation_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    data_module = DataModule(
        tokenizer=tokenizer,
        args=args,
    )

    model = LanguageModel(args=args, tokenizer=tokenizer)
    trainer.fit(model, datamodule=data_module)
[ ]: