{ "cells": [ { "cell_type": "markdown", "id": "2055f09e-ea78-4726-9a48-65115000b140", "metadata": {}, "source": [ "# Thunder bindings for Liger operators\n", "\n", "In this notebook we explore Thunder Bindings for Liger Operators.\n", "\n", "It is based on [Episode 10 of the Thunder Sessions podcast](https://www.youtube.com/watch?v=3H_aw6o-d9c&list=PLaMu-SDt_RB7ImARcTT_Wjypwx2vBIBen&index=10).\n", "\n", "Let's import things." ] }, { "cell_type": "code", "execution_count": 1, "id": "8f4102a8-f68b-4012-bd5b-a1d3daeab367", "metadata": {}, "outputs": [], "source": [ "from collections.abc import Sequence\n", "import math\n", "\n", "import torch\n", "from torch.testing import assert_close\n", "import litgpt\n", "import thunder\n", "from thunder.core.proxies import TensorProxy, AnyProxy\n", "from thunder.core.transforms import get_grad, put_grads\n", "from thunder.torch import TensorLike\n", "import thunder.extend\n", "\n", "import liger_kernel.ops.rms_norm\n", "import liger_kernel.ops.rope\n", "import liger_kernel.ops.swiglu\n", "import liger_kernel.ops.geglu # TODO\n", "import liger_kernel.ops.cross_entropy # TODO\n", "import liger_kernel.ops.fused_linear_cross_entropy\n", "\n", "device = torch.device(\"cuda\")" ] }, { "cell_type": "markdown", "id": "6b44643e-a92c-4398-861f-793cec2e7414", "metadata": {}, "source": [ "We define and register an executor." ] }, { "cell_type": "code", "execution_count": 2, "id": "c5232472-a67c-4650-abf9-370e4692e93d", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "thunder.extend.OperatorExecutor('liger')" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "liger_ex = thunder.extend.OperatorExecutor(\"liger\", version=\"0.1\")\n", "thunder.extend.register_executor(liger_ex)" ] }, { "cell_type": "markdown", "id": "b207657e-a40c-4cda-a2d6-3f0e11ae4949", "metadata": {}, "source": [ "## RMS Norm\n", "\n", "The first thing to fuse is RMS Norm.\n", "\n", "After that, Liger's implementation is a drop-in replacement. We define operators for forward and backward and then a gradient and execution rule.\n", "\n", "We register these as an implementation for the rms_norm operand that we divert the PyTorch function to." ] }, { "cell_type": "code", "execution_count": 3, "id": "4411cc5f-5535-48e2-ba7c-00b984f15ad2", "metadata": {}, "outputs": [], "source": [ "# A tiny detail here is that PyTorch gained a `rms_norm` function somewhat\n", "# recently and we need to tell LitGPT to use it.\n", "\n", "\n", "def RMSNorm_forward(self, x):\n", " return torch.nn.functional.rms_norm(x, self.weight.shape, self.weight, self.eps)\n", "\n", "\n", "litgpt.model.RMSNorm.forward = RMSNorm_forward" ] }, { "cell_type": "code", "execution_count": 4, "id": "77757535-b292-4a96-a6a3-c0e7f05d70ea", "metadata": {}, "outputs": [], "source": [ "import functools\n", "\n", "prod = lambda *args: functools.reduce(lambda x, y: x * y, args)" ] }, { "cell_type": "code", "execution_count": 5, "id": "f542954c-aba1-4523-9a7d-436348a6af96", "metadata": {}, "outputs": [], "source": [ "# ******************************* RMS NORM *******************************\n", "import functools\n", "\n", "\n", "def liger_rms_norm_forward_meta(X, W, eps, offset, casting_mode):\n", " *n_rows, n_cols = X.shape\n", " n_rows = prod(*n_rows)\n", " # RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode\n", " rstd_dtype = (\n", " thunder.dtypes.float32\n", " if casting_mode\n", " in (liger_kernel.ops.rms_norm._CASTING_MODE_LLAMA.value, liger_kernel.ops.rms_norm._CASTING_MODE_GEMMA.value)\n", " else X.dtype\n", " )\n", " Y = TensorProxy(like=X)\n", " RSTD = TensorProxy(like=X, shape=(n_rows,), dtype=rstd_dtype)\n", " BLOCK_SIZE, num_warps = liger_kernel.ops.rms_norm.calculate_settings(n_cols)\n", " return Y, TensorProxy(like=X, shape=(n_rows, n_cols)), RSTD, BLOCK_SIZE, num_warps, casting_mode\n", "\n", "\n", "liger_rms_norm_forward = liger_ex.register_operator(\n", " \"liger_rms_norm_forward\", meta=liger_rms_norm_forward_meta, fn=liger_kernel.ops.rms_norm.rms_norm_forward\n", ")\n", "\n", "\n", "def liger_rms_norm_backward_meta(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):\n", " return TensorProxy(like=X), TensorProxy(like=W)\n", "\n", "\n", "liger_rms_norm_backward = liger_ex.register_operator(\n", " \"liger_rms_norm_backward\", meta=liger_rms_norm_backward_meta, fn=liger_kernel.ops.rms_norm.rms_norm_backward\n", ")\n", "\n", "\n", "def rms_norm_meta(x, shape, w, eps):\n", " return thunder.TensorProxy(like=x)\n", "\n", "\n", "rms_norm = liger_ex.register_operator(\n", " \"rms_norm\", meta=rms_norm_meta, fn=torch.nn.functional.rms_norm, replaces=torch.nn.functional.rms_norm\n", ")\n", "\n", "\n", "def rms_norm_grad_transform(x, shape, weight, eps):\n", " Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = liger_rms_norm_forward(\n", " x, weight, eps, offset=0.0, casting_mode=\"llama\"\n", " )\n", " dY = get_grad(Y)\n", " dX, dW = liger_rms_norm_backward(\n", " dY, X, weight, RSTD, offset=0.0, casting_mode=\"llama\", BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps\n", " )\n", " dX = dX.view(*x.shape)\n", " put_grads((x, weight), (dX, dW))\n", " return Y\n", "\n", "\n", "def rms_norm_execution_transform(x, weight, eps):\n", " Y, *_ = liger_rms_norm_forward(x, weight, eps, offset=0.0, casting_mode=\"llama\")\n", " return Y\n", "\n", "\n", "liger_ex.register_implementation(\n", " rms_norm, execution_transform=rms_norm_execution_transform, grad_transform=rms_norm_grad_transform\n", ")" ] }, { "cell_type": "markdown", "id": "0ace1ad2-25f4-4a20-ad39-1f030bca0f38", "metadata": {}, "source": [ "### Testing RMS Norm\n", "\n", "Let's test." ] }, { "cell_type": "code", "execution_count": 6, "id": "56f1d6ee-a4ac-42f1-9d65-2c774cda4d18", "metadata": {}, "outputs": [], "source": [ "hidden_size = 64\n", "\n", "example_input = torch.randn(32, 10, hidden_size, device=device, requires_grad=True)\n", "\n", "with device:\n", " model = litgpt.model.RMSNorm(hidden_size)\n", "thunder_model = thunder.jit(model, executors=[liger_ex])\n", "ref = model(example_input.clone())\n", "res = thunder_model(example_input.clone())\n", "go = torch.randn_like(ref)\n", "grad_ref, grad_ref_weight = torch.autograd.grad(ref, (example_input, model.weight), go)\n", "grad_res, grad_res_weight = torch.autograd.grad(res, (example_input, model.weight), go)\n", "\n", "\n", "assert liger_rms_norm_forward in {bsym.sym for bsym in thunder.last_traces(thunder_model)[-1].bound_symbols}\n", "assert liger_rms_norm_backward in {bsym.sym for bsym in thunder.last_backward_traces(thunder_model)[-1].bound_symbols}\n", "\n", "assert_close(ref, res)\n", "assert_close(grad_ref, grad_res)\n", "assert_close(grad_ref_weight, grad_res_weight)" ] }, { "cell_type": "code", "execution_count": null, "id": "60dcb262-2255-4c17-b64f-f38e8ebd8e33", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "71c49c38-ce84-4727-9f57-8b42b908349f", "metadata": {}, "source": [ "# RoPE\n", "\n", "Next is the RoPE implementation. Liger does both rope applications to query and key in one kernel whereas\n", "LitGPT uses two. So we define not only forward and backward and a symbol to capture the litgpt version,\n", "but also a small transform fusing the two `apply_rope` calls to one `liger_rope`." ] }, { "cell_type": "code", "execution_count": null, "id": "32cd98f0-a36f-4e01-8ae4-6e36adf2699b", "metadata": {}, "outputs": [], "source": [ "def liger_rope_forward_meta(q, k, cos, sin):\n", " return TensorProxy(like=q), TensorProxy(like=k), cos, sin\n", "\n", "\n", "liger_rope_forward = liger_ex.register_operator(\n", " \"liger_rope_forward\",\n", " meta=liger_rope_forward_meta,\n", " fn=liger_kernel.ops.rope.rope_forward,\n", ")\n", "\n", "\n", "def liger_rope_backward_meta(dq, dk, cos, sin):\n", " return TensorLike(like=dq), TensorLike(like=dk)\n", "\n", "\n", "liger_rope_backward = liger_ex.register_operator(\n", " \"liger_rope_backward\",\n", " meta=liger_rope_backward_meta,\n", " fn=liger_kernel.ops.rope.rope_backward,\n", ")\n", "\n", "\n", "def liger_rope_grad_transform(q, k, cos, sin):\n", " q_out, k_out, _, _ = liger_rope_forward(q, k, cos, sin)\n", " q_out_grad = get_grad(q_out)\n", " k_out_grad = get_grad(k_out)\n", " dq, dk = liger_rope_backward(q_out_grad, k_out_grad, cos, sin)\n", " put_grads((q, k), (dq, dk))\n", " return q_out, k_out\n", "\n", "\n", "def liger_rope_execution_transform(q, k, cos, sin):\n", " q_out, k_out, _, _ = liger_rope_forward(q, k, cos, sin)\n", " return q_out, k_out\n", "\n", "\n", "def liger_rope_impl(q, k, cos, sin):\n", " qr, kr, _, _ = liger_rope_forward(q, k, cos, sin)\n", " return qr, kr\n", "\n", "\n", "liger_rope = liger_ex.register_operator(\"liger_rope\", fn=liger_rope_impl, like=liger_rope_impl)\n", "\n", "liger_ex.register_implementation(\n", " liger_rope,\n", " execution_transform=liger_rope_execution_transform,\n", " grad_transform=liger_rope_grad_transform,\n", ")\n", "\n", "\n", "def litgpt_apply_rope_meta(x, cos, sin):\n", " return TensorProxy(like=x)\n", "\n", "\n", "litgpt_apply_rope = liger_ex.register_operator(\n", " \"litgpt_apply_rope\", fn=litgpt.model.apply_rope, meta=litgpt_apply_rope_meta, replaces=litgpt.model.apply_rope\n", ")\n", "\n", "\n", "class MergeRopeTransform(thunder.core.transform_common.Transform):\n", " def transform_traces_pre_prologue(self, prologue_trace, compute_trace, epilogue_trace, **kwargs):\n", " new_compute_trace = thunder.core.trace.from_trace(compute_trace)\n", " bound_symbols = compute_trace.bound_symbols[:]\n", " while bound_symbols:\n", " bsym = bound_symbols.pop(0)\n", " if bsym.sym == litgpt_apply_rope:\n", " while bound_symbols:\n", " bsym2 = bound_symbols.pop(0)\n", " assert not any(o is bsym.output for o in bsym2.flat_outs)\n", " if bsym2.sym == litgpt_apply_rope:\n", " break\n", " new_compute_trace.bound_symbols.append(bsym2.from_bsym())\n", " assert bsym2.sym == litgpt_apply_rope\n", "\n", " output = (bsym.output, bsym2.output)\n", " args = (bsym.args[0], bsym2.args[0], *bsym.args[1:])\n", "\n", " new_compute_trace.bound_symbols.append(bsym.from_bsym(args=args, output=output, sym=liger_rope))\n", " else:\n", " new_compute_trace.bound_symbols.append(bsym.from_bsym())\n", " new_compute_trace.set_provenance(thunder.core.trace.TraceProvenance(self.__class__))\n", " return prologue_trace, new_compute_trace, epilogue_trace" ] }, { "cell_type": "markdown", "id": "44187b29-c101-41f0-a811-4c9f29757c81", "metadata": {}, "source": [ "# Test\n", "\n", "We test with a scaled-down Llama." ] }, { "cell_type": "code", "execution_count": 8, "id": "b8fd2563-7b89-487d-8fdc-21661380e2c0", "metadata": {}, "outputs": [], "source": [ "cfg = litgpt.Config.from_name(\"Llama-3.2-1B\", n_layer=1)\n", "with device:\n", " m = litgpt.GPT(cfg)\n", " m.max_seq_length = 1024\n", " m.set_kv_cache(1)\n", " inp = torch.arange(1, 6, dtype=torch.int64)[None]\n", " inp_pos = torch.arange(1, 6, dtype=torch.int64)\n", "\n", "\n", "jm = thunder.jit(m, executors=(liger_ex,), transforms=(MergeRopeTransform(),))\n", "res = jm(inp, inp_pos)\n", "\n", "go = torch.randn_like(res)\n", "(grad_res,) = torch.autograd.grad(res, jm.get_parameter(\"transformer.wte.weight\"), go)\n", "ref = m(inp, inp_pos)\n", "(grad_ref,) = torch.autograd.grad(ref, m.get_parameter(\"transformer.wte.weight\"), go)\n", "\n", "assert_close(res, ref)\n", "assert_close(grad_res, grad_ref)\n", "\n", "assert any(bsym.sym is liger_rope_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)\n", "assert any(bsym.sym is liger_rope_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)\n", "assert any(bsym.sym is liger_rms_norm_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)\n", "assert any(bsym.sym is liger_rms_norm_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)" ] }, { "cell_type": "markdown", "id": "e341460b-71d4-4e83-b67c-14bdec7d8026", "metadata": {}, "source": [ "## SwiGLU\n" ] }, { "cell_type": "code", "execution_count": 9, "id": "5e26e9ec-3eb0-46e8-90d1-8cd84d5dd1b7", "metadata": {}, "outputs": [], "source": [ "def liger_swiglu_forward_meta(a, b):\n", " return TensorProxy(like=a)\n", "\n", "\n", "def liger_swiglu_forward_impl(a, b):\n", " _, _, res = liger_kernel.ops.swiglu.swiglu_forward(a, b)\n", " return res\n", "\n", "\n", "liger_swiglu_forward = liger_ex.register_operator(\n", " \"liger_swiglu_forward\",\n", " meta=liger_swiglu_forward_meta,\n", " fn=liger_swiglu_forward_impl,\n", ")\n", "\n", "\n", "def liger_swiglu_backward_meta(a, b, grad_res):\n", " return TensorProxy(like=a), TensorProxy(like=b)\n", "\n", "\n", "liger_swiglu_backward = liger_ex.register_operator(\n", " \"liger_swiglu_backward\",\n", " meta=liger_swiglu_backward_meta,\n", " fn=liger_kernel.ops.swiglu.swiglu_backward,\n", ")\n", "\n", "\n", "def liger_swiglu_gradient_transform(a, b):\n", " res = liger_swiglu_forward(a, b)\n", " grad_res = get_grad(res)\n", " grad_a, grad_b = liger_swiglu_backward(a, b, grad_res)\n", " put_grads((a, b), (grad_a, grad_b))\n", " return res\n", "\n", "\n", "liger_ex.register_implementation(\n", " liger_swiglu_forward, grad_transform=liger_swiglu_gradient_transform, execution_transform=liger_swiglu_forward\n", ")\n", "\n", "\n", "class FuseSwigLUTransform(thunder.core.transform_common.Transform):\n", " def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):\n", " _, consumers = thunder.core.utils.producers_and_consumers(computation_trace)\n", " new_computation_trace = thunder.core.trace.from_trace(computation_trace)\n", " bsyms_to_skip = set()\n", " for b in computation_trace.bound_symbols:\n", " if b in bsyms_to_skip:\n", " continue\n", " new_bsym = b\n", " if b.sym == thunder.torch.silu:\n", " c = consumers[b.output]\n", " if len(c) == 1 and c[0].sym == thunder.torch.mul:\n", " (mul,) = c\n", " mul_l, mul_r = mul.args\n", " if mul_l is b.output:\n", " other = mul_r\n", " else:\n", " other = mul_l\n", " new_bsym = b.from_bsym(\n", " sym=liger_swiglu_forward, output=mul.output, args=(b.args[0], other), subsymbols=[]\n", " )\n", " bsyms_to_skip.add(mul)\n", " new_computation_trace.bound_symbols.append(new_bsym)\n", " new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance(\"constructed by FuseSwigLU\"))\n", " return prologue_trace, new_computation_trace, epilogue_trace" ] }, { "cell_type": "code", "execution_count": null, "id": "c004b1f6-9756-44ae-88f3-d088f0c838c6", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "efb5dcb5-a411-4832-9cac-a868dc3142b0", "metadata": {}, "source": [ "## Fused Linear and Cross Entropy" ] }, { "cell_type": "code", "execution_count": 10, "id": "55ff0b33-99ec-4de1-a3c0-7a78ebdf83c4", "metadata": {}, "outputs": [], "source": [ "def liger_fused_linear_cross_entropy_forward_meta(\n", " _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0, reduction=\"mean\"\n", "):\n", " logits = thunder.torch.linear(_input, weight, bias)\n", " loss = thunder.torch.cross_entropy(\n", " logits, target, ignore_index=ignore_index, label_smoothing=label_smoothing, reduction=reduction\n", " )\n", " grad_input = TensorProxy(like=_input)\n", " grad_weight = TensorProxy(like=weight)\n", " grad_bias = None if bias is None else TensorProxy(like=bias)\n", " return loss, grad_input, grad_weight, grad_bias\n", "\n", "\n", "liger_fused_linear_cross_entropy_forward = liger_ex.register_operator(\n", " \"liger_fused_linear_cross_entropy_forward\",\n", " fn=liger_kernel.ops.fused_linear_cross_entropy.fused_linear_cross_entropy_forward,\n", " like=liger_fused_linear_cross_entropy_forward_meta,\n", ")\n", "\n", "\n", "def liger_fused_linear_cross_entropy_backward_meta(grad_output, grad_input, grad_weight, grad_bias):\n", " return (\n", " TensorProxy(like=grad_input),\n", " TensorProxy(like=grad_weight),\n", " (TensorProxy(like=grad_bias) if grad_bias is not None else None),\n", " )\n", "\n", "\n", "liger_fused_linear_cross_entropy_backward = liger_ex.register_operator(\n", " \"liger_fused_linear_cross_entropy_backward\",\n", " fn=liger_kernel.ops.fused_linear_cross_entropy.fused_linear_cross_entropy_backward,\n", " meta=liger_fused_linear_cross_entropy_backward_meta,\n", ")\n", "\n", "\n", "def liger_fused_linear_cross_entropy_grad_transform(\n", " _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0, reduction=\"mean\"\n", "):\n", " loss, grad_input_1, grad_weight_1, grad_bias_1 = liger_fused_linear_cross_entropy_forward(\n", " _input,\n", " weight,\n", " target,\n", " bias=bias,\n", " ignore_index=ignore_index,\n", " label_smoothing=label_smoothing,\n", " reduction=reduction,\n", " )\n", " grad_loss = get_grad(loss)\n", " grad_input, grad_weight, grad_bias = liger_fused_linear_cross_entropy_backward(\n", " grad_loss, grad_input_1, grad_weight_1, grad_bias_1\n", " )\n", " put_grads((_input, weight, target), (grad_input, grad_weight, grad_bias))\n", " return loss\n", "\n", "\n", "liger_ex.register_implementation(\n", " liger_fused_linear_cross_entropy_forward,\n", " grad_transform=liger_fused_linear_cross_entropy_grad_transform,\n", " execution_transform=liger_fused_linear_cross_entropy_forward,\n", ")\n", "\n", "\n", "class FuseLinearCrossEntropyTransform(thunder.core.transform_common.Transform):\n", " def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):\n", " _, consumers = thunder.core.utils.producers_and_consumers(computation_trace)\n", " new_computation_trace = thunder.core.trace.from_trace(computation_trace)\n", " bsyms_to_skip = set()\n", " for b in computation_trace.bound_symbols:\n", " if b in bsyms_to_skip:\n", " continue\n", " new_bsym = b\n", " if b.sym == thunder.torch.linear:\n", " c = consumers[b.output]\n", " if len(c) == 1 and c[0].sym == thunder.torch.cross_entropy:\n", " (ce,) = c\n", " assert not ce.kwargs\n", " assert not b.kwargs\n", " assert ce.args[0] is b.output\n", " inp, weight, bias = b.args\n", " _, targets, ce_weight, size_average, ignore_index, reduce, reduction, label_smoothing = ce.args\n", " assert ce_weight is None\n", " assert size_average is None\n", " assert reduce is None\n", " new_bsym = b.from_bsym(\n", " sym=liger_fused_linear_cross_entropy_forward,\n", " output=ce.output,\n", " args=(inp, weight, targets, bias, ignore_index, label_smoothing, reduction),\n", " subsymbols=[],\n", " )\n", " bsyms_to_skip.add(ce)\n", " new_computation_trace.bound_symbols.append(new_bsym)\n", " new_computation_trace.set_provenance(\n", " thunder.core.trace.TraceProvenance(\"constructed by FuseLinearCrossEntropy\")\n", " )\n", " return prologue_trace, new_computation_trace, epilogue_trace" ] }, { "cell_type": "code", "execution_count": 11, "id": "89431922-f074-4825-a6b6-7365abe5b0b4", "metadata": { "scrolled": true }, "outputs": [], "source": [ "def apply_eye_meta(x):\n", " return thunder.TensorProxy(like=x)\n", "\n", "\n", "def apply_eye(mask):\n", " mask = mask | torch.eye(mask.shape[-1], dtype=torch.bool, device=mask.device)[None, None]\n", " return mask\n", "\n", "\n", "t_apply_eye = liger_ex.register_operator(\"t_apply_eye\", fn=apply_eye, meta=apply_eye_meta, replaces=apply_eye)\n", "\n", "\n", "def apply_eye_grad_transform(x):\n", " return t_apply_eye(x)\n", "\n", "\n", "liger_ex.register_implementation(\n", " t_apply_eye, execution_transform=apply_eye_grad_transform, grad_transform=apply_eye_grad_transform\n", ")\n", "\n", "\n", "class GPTForFineTuningLastToken(litgpt.model.GPT):\n", " def forward(self, idx: torch.Tensor, *, mask: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:\n", " mask = mask.bool()\n", " T = idx.size(1)\n", " if self.max_seq_length < T:\n", " raise ValueError(f\"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.\")\n", "\n", " attn_mask = (\n", " litgpt.model.build_mask_cache(mask.shape[-1], mask.device).expand(4, -1, -1, -1) * mask[:, None, None, :]\n", " )\n", " attn_mask = apply_eye(attn_mask)\n", "\n", " cos = self.cos[:T]\n", " sin = self.sin[:T]\n", " x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)\n", " if self.config.scale_embeddings:\n", " x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype)\n", "\n", " for block in self.transformer.h:\n", " x = block(x, cos, sin, attn_mask, None)\n", "\n", " # second to last prediction is the output\n", " x = x[:, -2]\n", " x = self.transformer.ln_f(x)\n", " x = self.lm_head(x) # (b, t, vocab_size)\n", " if self.config.final_logit_softcapping is not None:\n", " x = torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping\n", " loss = torch.nn.functional.cross_entropy(x, labels)\n", " return loss\n", "\n", "\n", "cfg = litgpt.Config.from_name(\"Llama-3.2-1B\", n_layer=1)\n", "with device:\n", " m = GPTForFineTuningLastToken(cfg)\n", " m.max_seq_length = 1024\n", " inp = torch.ones(4, 32, dtype=torch.int64)\n", " mask = torch.ones(4, 32, dtype=torch.int64)\n", " labels = torch.ones(4, dtype=torch.int64)\n", "\n", "\n", "jm = thunder.jit(\n", " m,\n", " executors=(liger_ex,),\n", " transforms=(\n", " MergeRopeTransform(),\n", " FuseSwigLUTransform(),\n", " FuseLinearCrossEntropyTransform(),\n", " ),\n", ")\n", "res = jm(inp, mask=mask, labels=labels)\n", "ref = m(inp, mask=mask, labels=labels)\n", "\n", "go = torch.randn_like(res)\n", "(grad_res,) = torch.autograd.grad(res, jm.get_parameter(\"transformer.wte.weight\"), go)\n", "(grad_ref,) = torch.autograd.grad(ref, m.get_parameter(\"transformer.wte.weight\"), go)\n", "\n", "assert_close(res, ref)\n", "assert_close(grad_res, grad_ref)\n", "\n", "assert any(bsym.sym is liger_rope_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)\n", "assert any(bsym.sym is liger_rope_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)\n", "assert any(bsym.sym is liger_rms_norm_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)\n", "assert any(bsym.sym is liger_rms_norm_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)\n", "assert any(bsym.sym is liger_swiglu_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)\n", "assert any(bsym.sym is liger_swiglu_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)\n", "assert any(bsym.sym is liger_fused_linear_cross_entropy_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)\n", "assert any(\n", " bsym.sym is liger_fused_linear_cross_entropy_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols\n", ")" ] }, { "cell_type": "markdown", "id": "2812d18c-c6a6-4d02-ba12-7a8002efc0e5", "metadata": {}, "source": [ "# End to end example\n", "\n", "adapted from a [Liger-Kernel example](https://github.com/linkedin/Liger-Kernel/blob/de12602d858a6e83aaacc56e5cb64ab218c75a0a/examples/lightning/training.py).\n", "\n", "Code below is\n", "\n", "Copyright 2024 LinkedIn Corporation ([BSD 2-CLAUSE LICENSE](https://github.com/linkedin/Liger-Kernel/blob/de12602d858a6e83aaacc56e5cb64ab218c75a0a/LICENSE))\n" ] }, { "cell_type": "code", "execution_count": 12, "id": "420850af-878a-415a-bacd-a0d3258a0cc3", "metadata": {}, "outputs": [], "source": [ "if False: # this example has additional dependencies, so we skip it in the CI\n", " import argparse\n", " import math\n", " import os\n", " from dataclasses import _MISSING_TYPE, dataclass\n", " import litgpt\n", " \n", " import datasets\n", " import lightning.pytorch as pl\n", " import torch\n", " import transformers\n", " from lightning.pytorch.strategies import DeepSpeedStrategy, FSDPStrategy\n", " from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision\n", " from torch.utils.data import DataLoader\n", " from trl import DataCollatorForCompletionOnlyLM\n", " import warnings\n", " \n", " warnings.simplefilter(action=\"ignore\", category=FutureWarning)\n", " \n", " \n", " _RETAIN_COLUMNS = {\"input_ids\", \"attention_mask\", \"labels\"}\n", " QUESTION = \"\"\n", " CHOICES = \"\"\n", " \n", " \n", " @dataclass\n", " class Args:\n", " model: str = \"meta-llama/Llama-3.2-1B-Instruct\"\n", " data: str = \"cais/mmlu\"\n", " output_dir: str = \"mmlu_finetuning\"\n", " max_length: int = 2048\n", " # for llam3 8B model, deepspeed will OOM with 16 on 8XA100 80G and 8 will OOM with 8XA100 40G\n", " batch_size: int = 4\n", " lr: float = 6e-6\n", " weight_decay: float = 0.05\n", " warmup_ratio: float = 0.1\n", " seed: int = 42\n", " strategy: str = \"auto\"\n", " num_gpu: int = 1\n", " \n", " \n", " def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0):\n", " def lr_lambda(current_step):\n", " if current_step < warmup_steps:\n", " # Linear warmup\n", " return float(current_step) / float(max(1, warmup_steps))\n", " else:\n", " # Cosine annealing\n", " progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))\n", " return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress)))\n", " \n", " return lr_lambda\n", " \n", " \n", " def parse_args() -> Args:\n", " parser = argparse.ArgumentParser()\n", " for k, v in Args.__dataclass_fields__.items():\n", " parser.add_argument(f\"--{k}\", type=v.type, default=v.default)\n", " parsed = parser.parse_args([])\n", " return Args(**{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)})\n", " \n", " \n", " class LanguageModel(pl.LightningModule):\n", " def __init__(self, args: Args, tokenizer):\n", " super().__init__()\n", " self.args = args\n", " self.tokenizer = tokenizer\n", " self.model = None\n", " \n", " def configure_model(self):\n", " # https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization\n", " if self.model is not None:\n", " return\n", " self.model = GPTForFineTuningLastToken.from_name(self.args.model.rsplit(\"/\", 1)[-1]).to(torch.bfloat16)\n", " self.model.load_state_dict(litgpt.utils.lazy_load(f\"checkpoints/{self.args.model}/lit_model.pth\"))\n", " self.model = thunder.jit(\n", " self.model,\n", " executors=(liger_ex, *thunder.get_default_executors()),\n", " transforms=(MergeRopeTransform(), FuseSwigLUTransform(), FuseLinearCrossEntropyTransform()),\n", " )\n", " \n", " def forward(self, input_ids, attention_mask, labels=None, **kwargs):\n", " return self.model(idx=input_ids, mask=attention_mask, labels=labels, **kwargs)\n", " \n", " def training_step(self, batch):\n", " outputs = self.model(\n", " idx=batch[\"input_ids\"],\n", " mask=batch[\"attention_mask\"],\n", " labels=batch[\"labels\"][:, -1],\n", " )\n", " loss = outputs\n", " self.log_dict(\n", " {\"train_loss\": loss},\n", " on_step=True,\n", " on_epoch=True,\n", " prog_bar=True,\n", " logger=True,\n", " rank_zero_only=True,\n", " sync_dist=False,\n", " )\n", " return loss\n", " \n", " def validation_step(self, batch):\n", " outputs = self.model(\n", " idx=batch[\"input_ids\"],\n", " mask=batch[\"attention_mask\"],\n", " labels=batch[\"labels\"][:, -1],\n", " )\n", " loss = outputs\n", " self.log_dict(\n", " {\"val_loss\": loss},\n", " on_step=True,\n", " on_epoch=True,\n", " prog_bar=True,\n", " logger=True,\n", " rank_zero_only=True,\n", " sync_dist=True,\n", " )\n", " return loss\n", " \n", " def configure_optimizers(self):\n", " optimizer = torch.optim.AdamW(\n", " self.parameters(),\n", " lr=self.args.lr,\n", " weight_decay=self.args.weight_decay,\n", " fused=True,\n", " )\n", " lr_lambda = warmup_cosine_schedule(\n", " warmup_steps=self.trainer.estimated_stepping_batches * self.args.warmup_ratio,\n", " total_steps=self.trainer.estimated_stepping_batches,\n", " min_lr=0,\n", " )\n", " lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n", " return {\n", " \"optimizer\": optimizer,\n", " \"lr_scheduler\": {\"scheduler\": lr_scheduler, \"interval\": \"step\"},\n", " }\n", " \n", " \n", " class DataModule(pl.LightningDataModule):\n", " def __init__(self, tokenizer, args: Args):\n", " super().__init__()\n", " self.train_dataset = None\n", " self.args = args\n", " self.tokenizer = tokenizer\n", " self.response_template_str = \" \"\n", " response_prompt = tokenizer.encode(f\"{self.response_template_str}\", add_special_tokens=False)\n", " self.collator = DataCollatorForCompletionOnlyLM(\n", " tokenizer=tokenizer,\n", " response_template=response_prompt,\n", " pad_to_multiple_of=16,\n", " )\n", " \n", " def formatting_func(self, example):\n", " output_texts = []\n", " for i in range(len(example[\"question\"])):\n", " choices = \"\"\n", " for j in range(len(example[\"choices\"][i])):\n", " choices += f\"{j+1}. {example['choices'][i][j]}; \"\n", " s = \"Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. \"\n", " s += f\"{QUESTION}{example['question'][i]} \"\n", " s += f\"{CHOICES}{choices} \"\n", " s += f\"{self.response_template_str}{example['answer'][i]}\"\n", " output_texts.append(s)\n", " return output_texts\n", " \n", " def tokenize(self, example):\n", " outputs = self.tokenizer(\n", " self.formatting_func(example),\n", " truncation=True,\n", " padding=False,\n", " max_length=self.args.max_length,\n", " )\n", " return {\n", " \"input_ids\": outputs[\"input_ids\"],\n", " \"attention_mask\": outputs[\"attention_mask\"],\n", " }\n", " \n", " def setup(self, stage) -> None:\n", " if self.train_dataset is not None:\n", " return\n", " dataset = datasets.load_dataset(self.args.data, \"auxiliary_train\")\n", " flattened_data = [\n", " {\n", " \"answer\": x[\"train\"][\"answer\"],\n", " \"choices\": x[\"train\"][\"choices\"],\n", " \"question\": x[\"train\"][\"question\"],\n", " \"subject\": x[\"train\"][\"subject\"],\n", " }\n", " for x in dataset[\"train\"]\n", " ][:32]\n", " dataset = datasets.Dataset.from_list(flattened_data)\n", " dataset = dataset.train_test_split(test_size=4, seed=self.args.seed)\n", " train_dataset, val_dataset = dataset[\"train\"], dataset[\"test\"]\n", " self.train_dataset = train_dataset.map(\n", " self.tokenize,\n", " remove_columns=list(set(train_dataset.column_names) - _RETAIN_COLUMNS),\n", " batched=True,\n", " batch_size=1,\n", " num_proc=4,\n", " )\n", " self.val_dataset = val_dataset.map(\n", " self.tokenize,\n", " remove_columns=list(set(val_dataset.column_names) - _RETAIN_COLUMNS),\n", " batched=True,\n", " batch_size=1,\n", " num_proc=4,\n", " )\n", " \n", " def train_dataloader(self):\n", " return DataLoader(\n", " self.train_dataset,\n", " batch_size=self.args.batch_size,\n", " collate_fn=self.collator,\n", " )\n", " \n", " def val_dataloader(self):\n", " return DataLoader(\n", " self.val_dataset,\n", " batch_size=self.args.batch_size,\n", " collate_fn=self.collator,\n", " )\n", " \n", " \n", " args = parse_args()\n", " pl.seed_everything(args.seed)\n", " os.makedirs(args.output_dir, exist_ok=True)\n", " \n", " if args.strategy == \"fsdp\":\n", " strategy = FSDPStrategy(\n", " auto_wrap_policy=layers,\n", " sharding_strategy=\"FULL_SHARD\",\n", " backward_prefetch=BackwardPrefetch.BACKWARD_PRE,\n", " sync_module_states=True,\n", " activation_checkpointing_policy=layers,\n", " mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),\n", " forward_prefetch=True,\n", " )\n", " precision = None\n", " elif args.strategy == \"deepspeed\":\n", " strategy = DeepSpeedStrategy(stage=3)\n", " precision = \"bf16-mixed\"\n", " elif args.strategy == \"ddp\":\n", " strategy = \"ddp\"\n", " precision = \"bf16-true\"\n", " else:\n", " strategy = \"auto\"\n", " precision = \"bf16-true\"\n", "\n", " # This only works if you have a snapshot to work from.\n", " trainer = pl.Trainer(\n", " accelerator=\"cuda\",\n", " strategy=strategy,\n", " devices=torch.cuda.device_count() if args.num_gpu is None else args.num_gpu,\n", " default_root_dir=args.output_dir,\n", " log_every_n_steps=1,\n", " max_epochs=1,\n", " precision=precision,\n", " )\n", "\n", " tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, padding_side=\"left\", truncation_side=\"left\")\n", " tokenizer.pad_token = tokenizer.eos_token\n", " data_module = DataModule(\n", " tokenizer=tokenizer,\n", " args=args,\n", " )\n", "\n", " model = LanguageModel(args=args, tokenizer=tokenizer)\n", " trainer.fit(model, datamodule=data_module)" ] }, { "cell_type": "code", "execution_count": null, "id": "434dbcf7-1ed3-4669-a90b-12044909be44", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }