Thunder bindings for Liger operators
In this notebook we explore Thunder bindings for Liger operators.
It is based on Episode 10 of the Thunder Sessions podcast.
Let’s import things.
[1]:
from collections.abc import Sequence
import math
import torch
from torch.testing import assert_close
import litgpt
import thunder
from thunder.core.proxies import TensorProxy, AnyProxy
from thunder.core.transforms import get_grad, put_grads
from thunder.torch import TensorLike
import thunder.extend
import liger_kernel.ops.rms_norm
import liger_kernel.ops.rope
import liger_kernel.ops.swiglu
import liger_kernel.ops.geglu # TODO
import liger_kernel.ops.cross_entropy # TODO
import liger_kernel.ops.fused_linear_cross_entropy
device = torch.device("cuda")
We define and register an executor.
[2]:
liger_ex = thunder.extend.OperatorExecutor("liger", version="0.1")
thunder.extend.register_executor(liger_ex)
[2]:
thunder.extend.OperatorExecutor('liger')
RMS Norm
The first thing to fuse is RMS Norm.
Apart from a small detail we handle below, Liger's implementation is a drop-in replacement. We define operators for the forward and the backward, then a gradient rule and an execution rule.
We register these as an implementation of the rms_norm operator to which we divert the PyTorch function.
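As a reference for what the kernel computes, here is a plain PyTorch sketch of RMS norm (Liger's kernel additionally takes offset and casting_mode arguments; throughout this notebook we pass offset=0.0 and casting_mode="llama"). The helper name is ours and is not used elsewhere:
[ ]:
def rms_norm_reference(x, weight, eps):
    # normalize by the root mean square over the last dimension, then scale
    rstd = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rstd * weight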
[3]:
# A tiny detail here is that PyTorch gained a `rms_norm` function somewhat
# recently and we need to tell LitGPT to use it.
def RMSNorm_forward(self, x):
return torch.nn.functional.rms_norm(x, self.weight.shape, self.weight, self.eps)
litgpt.model.RMSNorm.forward = RMSNorm_forward
[4]:
import functools
prod = lambda *args: functools.reduce(lambda x, y: x * y, args, 1)  # initial value 1 so an empty argument list (1-d inputs) gives 1
[5]:
# ******************************* RMS NORM *******************************
import functools
def liger_rms_norm_forward_meta(X, W, eps, offset, casting_mode):
*n_rows, n_cols = X.shape
n_rows = prod(*n_rows)
# RSTD is always computed/stored in fp32 if we are using Llama or Gemma casting mode
rstd_dtype = (
thunder.dtypes.float32
if casting_mode
in (liger_kernel.ops.rms_norm._CASTING_MODE_LLAMA.value, liger_kernel.ops.rms_norm._CASTING_MODE_GEMMA.value)
else X.dtype
)
Y = TensorProxy(like=X)
RSTD = TensorProxy(like=X, shape=(n_rows,), dtype=rstd_dtype)
BLOCK_SIZE, num_warps = liger_kernel.ops.rms_norm.calculate_settings(n_cols)
return Y, TensorProxy(like=X, shape=(n_rows, n_cols)), RSTD, BLOCK_SIZE, num_warps, casting_mode
liger_rms_norm_forward = liger_ex.register_operator(
"liger_rms_norm_forward", meta=liger_rms_norm_forward_meta, fn=liger_kernel.ops.rms_norm.rms_norm_forward
)
def liger_rms_norm_backward_meta(dY, X, W, RSTD, offset, casting_mode, BLOCK_SIZE, num_warps):
return TensorProxy(like=X), TensorProxy(like=W)
liger_rms_norm_backward = liger_ex.register_operator(
"liger_rms_norm_backward", meta=liger_rms_norm_backward_meta, fn=liger_kernel.ops.rms_norm.rms_norm_backward
)
def rms_norm_meta(x, shape, w, eps):
return thunder.TensorProxy(like=x)
rms_norm = liger_ex.register_operator(
"rms_norm", meta=rms_norm_meta, fn=torch.nn.functional.rms_norm, replaces=torch.nn.functional.rms_norm
)
def rms_norm_grad_transform(x, shape, weight, eps):
Y, X, RSTD, BLOCK_SIZE, num_warps, casting_mode = liger_rms_norm_forward(
x, weight, eps, offset=0.0, casting_mode="llama"
)
dY = get_grad(Y)
dX, dW = liger_rms_norm_backward(
dY, X, weight, RSTD, offset=0.0, casting_mode="llama", BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps
)
dX = dX.view(*x.shape)
put_grads((x, weight), (dX, dW))
return Y
def rms_norm_execution_transform(x, weight, eps):
Y, *_ = liger_rms_norm_forward(x, weight, eps, offset=0.0, casting_mode="llama")
return Y
liger_ex.register_implementation(
rms_norm, execution_transform=rms_norm_execution_transform, grad_transform=rms_norm_grad_transform
)
Testing RMS Norm
Let’s test.
[6]:
hidden_size = 64
example_input = torch.randn(32, 10, hidden_size, device=device, requires_grad=True)
with device:
model = litgpt.model.RMSNorm(hidden_size)
thunder_model = thunder.jit(model, executors=[liger_ex])
ref = model(example_input.clone())
res = thunder_model(example_input.clone())
go = torch.randn_like(ref)
grad_ref, grad_ref_weight = torch.autograd.grad(ref, (example_input, model.weight), go)
grad_res, grad_res_weight = torch.autograd.grad(res, (example_input, model.weight), go)
assert liger_rms_norm_forward in {bsym.sym for bsym in thunder.last_traces(thunder_model)[-1].bound_symbols}
assert liger_rms_norm_backward in {bsym.sym for bsym in thunder.last_backward_traces(thunder_model)[-1].bound_symbols}
assert_close(ref, res)
assert_close(grad_ref, grad_res)
assert_close(grad_ref_weight, grad_res_weight)
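Besides the asserts above, it is instructive to print the optimized traces and see where the Liger kernels ended up (the forward call in the last computation trace, the backward in the last backward trace):
[ ]:
print(thunder.last_traces(thunder_model)[-1])
print(thunder.last_backward_traces(thunder_model)[-1])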
[ ]:
RoPE
Next is the RoPE implementation. Liger applies RoPE to the query and the key in a single kernel, whereas LitGPT uses two separate calls. So we define not only forward and backward operators and a symbol to capture the LitGPT version, but also a small transform that fuses the two apply_rope calls into one liger_rope call.
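In trace terms, the transform performs the following rewrite (a sketch, not literal trace output):
# before the transform: two independent rope applications, as recorded from LitGPT
#   q_roped = litgpt_apply_rope(q, cos, sin)
#   k_roped = litgpt_apply_rope(k, cos, sin)
# after the transform: a single fused symbol executed by the Liger kernel
#   q_roped, k_roped = liger_rope(q, k, cos, sin)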
[7]:
def liger_rope_forward_meta(q, k, cos, sin):
return TensorProxy(like=q), TensorProxy(like=k), cos, sin
liger_rope_forward = liger_ex.register_operator(
"liger_rope_forward",
meta=liger_rope_forward_meta,
fn=liger_kernel.ops.rope.rope_forward,
)
def liger_rope_backward_meta(dq, dk, cos, sin):
return TensorLike(like=dq), TensorLike(like=dk)
liger_rope_backward = liger_ex.register_operator(
"liger_rope_backward",
meta=liger_rope_backward_meta,
fn=liger_kernel.ops.rope.rope_backward,
)
def liger_rope_grad_transform(q, k, cos, sin):
q_out, k_out, _, _ = liger_rope_forward(q, k, cos, sin)
q_out_grad = get_grad(q_out)
k_out_grad = get_grad(k_out)
dq, dk = liger_rope_backward(q_out_grad, k_out_grad, cos, sin)
put_grads((q, k), (dq, dk))
return q_out, k_out
def liger_rope_execution_transform(q, k, cos, sin):
q_out, k_out, _, _ = liger_rope_forward(q, k, cos, sin)
return q_out, k_out
def liger_rope_impl(q, k, cos, sin):
qr, kr, _, _ = liger_rope_forward(q, k, cos, sin)
return qr, kr
liger_rope = liger_ex.register_operator("liger_rope", fn=liger_rope_impl, like=liger_rope_impl)
liger_ex.register_implementation(
liger_rope,
execution_transform=liger_rope_execution_transform,
grad_transform=liger_rope_grad_transform,
)
def litgpt_apply_rope_meta(x, cos, sin):
return TensorProxy(like=x)
litgpt_apply_rope = liger_ex.register_operator(
"litgpt_apply_rope", fn=litgpt.model.apply_rope, meta=litgpt_apply_rope_meta, replaces=litgpt.model.apply_rope
)
class MergeRopeTransform(thunder.core.transform_common.Transform):
def transform_traces_pre_prologue(self, prologue_trace, compute_trace, epilogue_trace, **kwargs):
new_compute_trace = thunder.core.trace.from_trace(compute_trace)
bound_symbols = compute_trace.bound_symbols[:]
while bound_symbols:
bsym = bound_symbols.pop(0)
if bsym.sym == litgpt_apply_rope:
for i, bsym2 in enumerate(bound_symbols):
assert not any(o is bsym.output for o in bsym2.flat_outs)
if bsym2.sym == litgpt_apply_rope:
break
bsym2 = bound_symbols.pop(i)
assert bsym2.sym == litgpt_apply_rope
output = (bsym.output, bsym2.output)
args = (bsym.args[0], bsym2.args[0], *bsym.args[1:])
new_compute_trace.bound_symbols.append(bsym.from_bsym(args=args, output=output, sym=liger_rope))
else:
new_compute_trace.bound_symbols.append(bsym.from_bsym())
new_compute_trace.set_provenance(thunder.core.trace.TraceProvenance(self.__class__))
return prologue_trace, new_compute_trace, epilogue_trace
Testing RoPE
We test with a scaled-down Llama.
[8]:
cfg = litgpt.Config.from_name("Llama-3.2-1B", n_layer=1)
with device:
m = litgpt.GPT(cfg)
m.max_seq_length = 1024
m.set_kv_cache(1)
inp = torch.arange(1, 6, dtype=torch.int64)[None]
inp_pos = torch.arange(1, 6, dtype=torch.int64)
jm = thunder.jit(m, executors=(liger_ex,), transforms=(MergeRopeTransform(),))
res = jm(inp, inp_pos)
go = torch.randn_like(res)
(grad_res,) = torch.autograd.grad(res, jm.get_parameter("transformer.wte.weight"), go)
ref = m(inp, inp_pos)
(grad_ref,) = torch.autograd.grad(ref, m.get_parameter("transformer.wte.weight"), go)
assert_close(res, ref)
assert_close(grad_res, grad_ref)
assert any(bsym.sym is liger_rope_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_rope_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_rms_norm_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_rms_norm_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)
SwiGLU
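LitGPT's MLP computes silu(x_fc_1) * x_fc_2. Liger provides a fused kernel for this, so we register forward and backward operators and, below, a transform that spots the silu-followed-by-mul pattern in the trace and replaces it with the fused operator. As a plain PyTorch reference for what the kernel computes (a sketch with a helper name of our choosing, not code used later):
[ ]:
def swiglu_reference(a, b):
    # SwiGLU as used in Llama-style MLPs: a SiLU-gated elementwise product
    return torch.nn.functional.silu(a) * b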
[9]:
def liger_swiglu_forward_meta(a, b):
return TensorProxy(like=a)
def liger_swiglu_forward_impl(a, b):
_, _, res = liger_kernel.ops.swiglu.swiglu_forward(a, b)
return res
liger_swiglu_forward = liger_ex.register_operator(
"liger_swiglu_forward",
meta=liger_swiglu_forward_meta,
fn=liger_swiglu_forward_impl,
)
def liger_swiglu_backward_meta(a, b, grad_res):
return TensorProxy(like=a), TensorProxy(like=b)
liger_swiglu_backward = liger_ex.register_operator(
"liger_swiglu_backward",
meta=liger_swiglu_backward_meta,
fn=liger_kernel.ops.swiglu.swiglu_backward,
)
def liger_swiglu_gradient_transform(a, b):
res = liger_swiglu_forward(a, b)
grad_res = get_grad(res)
grad_a, grad_b = liger_swiglu_backward(a, b, grad_res)
put_grads((a, b), (grad_a, grad_b))
return res
liger_ex.register_implementation(
liger_swiglu_forward, grad_transform=liger_swiglu_gradient_transform, execution_transform=liger_swiglu_forward
)
class FuseSwigLUTransform(thunder.core.transform_common.Transform):
def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
_, consumers = thunder.core.utils.producers_and_consumers(computation_trace)
new_computation_trace = thunder.core.trace.from_trace(computation_trace)
bsyms_to_skip = set()
for b in computation_trace.bound_symbols:
if b in bsyms_to_skip:
continue
new_bsym = b
if b.sym == thunder.torch.silu:
c = consumers[b.output]
if len(c) == 1 and c[0].sym == thunder.torch.mul:
(mul,) = c
mul_l, mul_r = mul.args
if mul_l is b.output:
other = mul_r
else:
other = mul_l
new_bsym = b.from_bsym(
sym=liger_swiglu_forward, output=mul.output, args=(b.args[0], other), subsymbols=[]
)
bsyms_to_skip.add(mul)
new_computation_trace.bound_symbols.append(new_bsym)
new_computation_trace.set_provenance(thunder.core.trace.TraceProvenance("constructed by FuseSwigLU"))
return prologue_trace, new_computation_trace, epilogue_trace
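The fusion is exercised as part of the end-to-end test further down, but a minimal standalone check along the lines of the tests above (with made-up shapes) could look like this:
[ ]:
def swiglu_fn(a, b):
    return torch.nn.functional.silu(a) * b

a = torch.randn(8, 64, device=device, requires_grad=True)
b = torch.randn(8, 64, device=device, requires_grad=True)
jfn = thunder.jit(swiglu_fn, executors=(liger_ex,), transforms=(FuseSwigLUTransform(),))
assert_close(jfn(a, b), swiglu_fn(a, b))
assert any(bsym.sym is liger_swiglu_forward for bsym in thunder.last_traces(jfn)[-1].bound_symbols)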
[ ]:
Fused Linear and Cross Entropy
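The last fusion combines the model's final linear layer with the cross-entropy loss, so the full logits never need to be materialized. Liger's forward already produces the gradients with respect to input, weight, and bias, which is why the forward operator returns them and the backward mainly applies the incoming gradient of the loss to them. For reference, the unfused computation that the operator replaces looks like this (a sketch, with a helper name of our choosing):
[ ]:
def linear_cross_entropy_reference(x, weight, target, bias=None):
    # what the fused Liger operator computes, written out with full logits
    logits = torch.nn.functional.linear(x, weight, bias)
    return torch.nn.functional.cross_entropy(logits, target)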
[10]:
def liger_fused_linear_cross_entropy_forward_meta(
_input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0, reduction="mean"
):
logits = thunder.torch.linear(_input, weight, bias)
loss = thunder.torch.cross_entropy(
logits, target, ignore_index=ignore_index, label_smoothing=label_smoothing, reduction=reduction
)
grad_input = TensorProxy(like=_input)
grad_weight = TensorProxy(like=weight)
grad_bias = None if bias is None else TensorProxy(like=bias)
return loss, grad_input, grad_weight, grad_bias
liger_fused_linear_cross_entropy_forward = liger_ex.register_operator(
"liger_fused_linear_cross_entropy_forward",
fn=liger_kernel.ops.fused_linear_cross_entropy.fused_linear_cross_entropy_forward,
like=liger_fused_linear_cross_entropy_forward_meta,
)
def liger_fused_linear_cross_entropy_backward_meta(grad_output, grad_input, grad_weight, grad_bias):
return (
TensorProxy(like=grad_input),
TensorProxy(like=grad_weight),
(TensorProxy(like=grad_bias) if grad_bias is not None else None),
)
liger_fused_linear_cross_entropy_backward = liger_ex.register_operator(
"liger_fused_linear_cross_entropy_backward",
fn=liger_kernel.ops.fused_linear_cross_entropy.fused_linear_cross_entropy_backward,
meta=liger_fused_linear_cross_entropy_backward_meta,
)
def liger_fused_linear_cross_entropy_grad_transform(
_input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0, reduction="mean"
):
loss, grad_input_1, grad_weight_1, grad_bias_1 = liger_fused_linear_cross_entropy_forward(
_input,
weight,
target,
bias=bias,
ignore_index=ignore_index,
label_smoothing=label_smoothing,
reduction=reduction,
)
grad_loss = get_grad(loss)
grad_input, grad_weight, grad_bias = liger_fused_linear_cross_entropy_backward(
grad_loss, grad_input_1, grad_weight_1, grad_bias_1
)
# the bias gradient, if any, belongs to bias (target has no gradient)
put_grads((_input, weight), (grad_input, grad_weight))
if bias is not None:
    put_grads((bias,), (grad_bias,))
return loss
liger_ex.register_implementation(
liger_fused_linear_cross_entropy_forward,
grad_transform=liger_fused_linear_cross_entropy_grad_transform,
execution_transform=liger_fused_linear_cross_entropy_forward,
)
class FuseLinearCrossEntropyTransform(thunder.core.transform_common.Transform):
def transform_traces_pre_prologue(self, prologue_trace, computation_trace, epilogue_trace, **kwargs):
_, consumers = thunder.core.utils.producers_and_consumers(computation_trace)
new_computation_trace = thunder.core.trace.from_trace(computation_trace)
bsyms_to_skip = set()
for b in computation_trace.bound_symbols:
if b in bsyms_to_skip:
continue
new_bsym = b
if b.sym == thunder.torch.linear:
c = consumers[b.output]
if len(c) == 1 and c[0].sym == thunder.torch.cross_entropy:
(ce,) = c
assert not ce.kwargs
assert not b.kwargs
assert ce.args[0] is b.output
inp, weight, bias = b.args
_, targets, ce_weight, size_average, ignore_index, reduce, reduction, label_smoothing = ce.args
assert ce_weight is None
assert size_average is None
assert reduce is None
new_bsym = b.from_bsym(
sym=liger_fused_linear_cross_entropy_forward,
output=ce.output,
args=(inp, weight, targets, bias, ignore_index, label_smoothing, reduction),
subsymbols=[],
)
bsyms_to_skip.add(ce)
new_computation_trace.bound_symbols.append(new_bsym)
new_computation_trace.set_provenance(
thunder.core.trace.TraceProvenance("constructed by FuseLinearCrossEntropy")
)
return prologue_trace, new_computation_trace, epilogue_trace
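Again the transform is covered by the end-to-end test below, but a minimal standalone check (a sketch with made-up shapes and helper names) could look like this:
[ ]:
def linear_then_ce(x, weight, target):
    logits = torch.nn.functional.linear(x, weight, None)
    return torch.nn.functional.cross_entropy(logits, target)

x = torch.randn(8, 64, device=device, requires_grad=True)
weight = torch.randn(32, 64, device=device, requires_grad=True)
target = torch.randint(0, 32, (8,), device=device)
jfn = thunder.jit(linear_then_ce, executors=(liger_ex,), transforms=(FuseLinearCrossEntropyTransform(),))
assert_close(jfn(x, weight, target), linear_then_ce(x, weight, target))
assert any(
    bsym.sym is liger_fused_linear_cross_entropy_forward for bsym in thunder.last_traces(jfn)[-1].bound_symbols
)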
[11]:
def apply_eye_meta(x):
return thunder.TensorProxy(like=x)
def apply_eye(mask):
mask = mask | torch.eye(mask.shape[-1], dtype=torch.bool, device=mask.device)[None, None]
return mask
t_apply_eye = liger_ex.register_operator("t_apply_eye", fn=apply_eye, meta=apply_eye_meta, replaces=apply_eye)
def apply_eye_grad_transform(x):
return t_apply_eye(x)
liger_ex.register_implementation(
t_apply_eye, execution_transform=apply_eye_grad_transform, grad_transform=apply_eye_grad_transform
)
class GPTForFineTuningLastToken(litgpt.model.GPT):
def forward(self, idx: torch.Tensor, *, mask: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
mask = mask.bool()
T = idx.size(1)
if self.max_seq_length < T:
raise ValueError(f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}.")
attn_mask = (
litgpt.model.build_mask_cache(mask.shape[-1], mask.device).expand(4, -1, -1, -1) * mask[:, None, None, :]
)
attn_mask = apply_eye(attn_mask)
cos = self.cos[:T]
sin = self.sin[:T]
x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
if self.config.scale_embeddings:
x = x * torch.tensor(self.config.n_embd**0.5, dtype=x.dtype)
for block in self.transformer.h:
x = block(x, cos, sin, attn_mask, None)
# second to last prediction is the output
x = x[:, -2]
x = self.transformer.ln_f(x)
x = self.lm_head(x) # (b, t, vocab_size)
if self.config.final_logit_softcapping is not None:
x = torch.tanh(x / self.config.final_logit_softcapping) * self.config.final_logit_softcapping
loss = torch.nn.functional.cross_entropy(x, labels)
return loss
cfg = litgpt.Config.from_name("Llama-3.2-1B", n_layer=1)
with device:
m = GPTForFineTuningLastToken(cfg)
m.max_seq_length = 1024
inp = torch.ones(4, 32, dtype=torch.int64)
mask = torch.ones(4, 32, dtype=torch.int64)
labels = torch.ones(4, dtype=torch.int64)
jm = thunder.jit(
m,
executors=(liger_ex,),
transforms=(
MergeRopeTransform(),
FuseSwigLUTransform(),
FuseLinearCrossEntropyTransform(),
),
)
res = jm(inp, mask=mask, labels=labels)
ref = m(inp, mask=mask, labels=labels)
go = torch.randn_like(res)
(grad_res,) = torch.autograd.grad(res, jm.get_parameter("transformer.wte.weight"), go)
(grad_ref,) = torch.autograd.grad(ref, m.get_parameter("transformer.wte.weight"), go)
assert_close(res, ref)
assert_close(grad_res, grad_ref)
assert any(bsym.sym is liger_rope_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_rope_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_rms_norm_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_rms_norm_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_swiglu_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_swiglu_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols)
assert any(bsym.sym is liger_fused_linear_cross_entropy_forward for bsym in thunder.last_traces(jm)[-1].bound_symbols)
assert any(
bsym.sym is liger_fused_linear_cross_entropy_backward for bsym in thunder.last_backward_traces(jm)[-1].bound_symbols
)
End-to-end example
The code below is adapted from a Liger-Kernel example: we fine-tune Llama 3.2 1B Instruct on MMLU with PyTorch Lightning.
Copyright 2024 LinkedIn Corporation (BSD 2-CLAUSE LICENSE)
[12]:
if False: # this example has additional dependencies, so we skip it in the CI
import argparse
import math
import os
from dataclasses import _MISSING_TYPE, dataclass
import litgpt
import datasets
import lightning.pytorch as pl
import torch
import transformers
from lightning.pytorch.strategies import DeepSpeedStrategy, FSDPStrategy
from torch.distributed.fsdp import BackwardPrefetch, MixedPrecision
from torch.utils.data import DataLoader
from trl import DataCollatorForCompletionOnlyLM
import warnings
warnings.simplefilter(action="ignore", category=FutureWarning)
_RETAIN_COLUMNS = {"input_ids", "attention_mask", "labels"}
QUESTION = "<Question>"
CHOICES = "<Choices>"
@dataclass
class Args:
model: str = "meta-llama/Llama-3.2-1B-Instruct"
data: str = "cais/mmlu"
output_dir: str = "mmlu_finetuning"
max_length: int = 2048
# for the Llama 3 8B model, deepspeed will OOM with 16 on 8XA100 80G and 8 will OOM with 8XA100 40G
batch_size: int = 4
lr: float = 6e-6
weight_decay: float = 0.05
warmup_ratio: float = 0.1
seed: int = 42
strategy: str = "auto"
num_gpu: int = 1
def warmup_cosine_schedule(warmup_steps, total_steps, min_lr=0):
def lr_lambda(current_step):
if current_step < warmup_steps:
# Linear warmup
return float(current_step) / float(max(1, warmup_steps))
else:
# Cosine annealing
progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
return max(min_lr, 0.5 * (1 + math.cos(math.pi * progress)))
return lr_lambda
def parse_args() -> Args:
parser = argparse.ArgumentParser()
for k, v in Args.__dataclass_fields__.items():
parser.add_argument(f"--{k}", type=v.type, default=v.default)
parsed = parser.parse_args([])
return Args(**{k: v for k, v in vars(parsed).items() if not isinstance(v, _MISSING_TYPE)})
class LanguageModel(pl.LightningModule):
def __init__(self, args: Args, tokenizer):
super().__init__()
self.args = args
self.tokenizer = tokenizer
self.model = None
def configure_model(self):
# https://lightning.ai/docs/pytorch/stable/advanced/model_parallel/fsdp.html#speed-up-model-initialization
if self.model is not None:
return
self.model = GPTForFineTuningLastToken.from_name(self.args.model.rsplit("/", 1)[-1]).to(torch.bfloat16)
self.model.load_state_dict(litgpt.utils.lazy_load(f"checkpoints/{self.args.model}/lit_model.pth"))
self.model = thunder.jit(
self.model,
executors=(liger_ex, *thunder.get_default_executors()),
transforms=(MergeRopeTransform(), FuseSwigLUTransform(), FuseLinearCrossEntropyTransform()),
)
def forward(self, input_ids, attention_mask, labels=None, **kwargs):
return self.model(idx=input_ids, mask=attention_mask, labels=labels, **kwargs)
def training_step(self, batch):
outputs = self.model(
idx=batch["input_ids"],
mask=batch["attention_mask"],
labels=batch["labels"][:, -1],
)
loss = outputs
self.log_dict(
{"train_loss": loss},
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
rank_zero_only=True,
sync_dist=False,
)
return loss
def validation_step(self, batch):
outputs = self.model(
idx=batch["input_ids"],
mask=batch["attention_mask"],
labels=batch["labels"][:, -1],
)
loss = outputs
self.log_dict(
{"val_loss": loss},
on_step=True,
on_epoch=True,
prog_bar=True,
logger=True,
rank_zero_only=True,
sync_dist=True,
)
return loss
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.args.lr,
weight_decay=self.args.weight_decay,
fused=True,
)
lr_lambda = warmup_cosine_schedule(
warmup_steps=self.trainer.estimated_stepping_batches * self.args.warmup_ratio,
total_steps=self.trainer.estimated_stepping_batches,
min_lr=0,
)
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
return {
"optimizer": optimizer,
"lr_scheduler": {"scheduler": lr_scheduler, "interval": "step"},
}
class DataModule(pl.LightningDataModule):
def __init__(self, tokenizer, args: Args):
super().__init__()
self.train_dataset = None
self.args = args
self.tokenizer = tokenizer
self.response_template_str = " <Answer>"
response_prompt = tokenizer.encode(f"{self.response_template_str}", add_special_tokens=False)
self.collator = DataCollatorForCompletionOnlyLM(
tokenizer=tokenizer,
response_template=response_prompt,
pad_to_multiple_of=16,
)
def formatting_func(self, example):
output_texts = []
for i in range(len(example["question"])):
choices = ""
for j in range(len(example["choices"][i])):
choices += f"{j+1}. {example['choices'][i][j]}; "
s = "Below is a question and multiple choice answers, choices separated by a semicolon. Please select the best answer for the question. "
s += f"{QUESTION}{example['question'][i]} "
s += f"{CHOICES}{choices} "
s += f"{self.response_template_str}{example['answer'][i]}"
output_texts.append(s)
return output_texts
def tokenize(self, example):
outputs = self.tokenizer(
self.formatting_func(example),
truncation=True,
padding=False,
max_length=self.args.max_length,
)
return {
"input_ids": outputs["input_ids"],
"attention_mask": outputs["attention_mask"],
}
def setup(self, stage) -> None:
if self.train_dataset is not None:
return
dataset = datasets.load_dataset(self.args.data, "auxiliary_train")
flattened_data = [
{
"answer": x["train"]["answer"],
"choices": x["train"]["choices"],
"question": x["train"]["question"],
"subject": x["train"]["subject"],
}
for x in dataset["train"]
][:32]
dataset = datasets.Dataset.from_list(flattened_data)
dataset = dataset.train_test_split(test_size=4, seed=self.args.seed)
train_dataset, val_dataset = dataset["train"], dataset["test"]
self.train_dataset = train_dataset.map(
self.tokenize,
remove_columns=list(set(train_dataset.column_names) - _RETAIN_COLUMNS),
batched=True,
batch_size=1,
num_proc=4,
)
self.val_dataset = val_dataset.map(
self.tokenize,
remove_columns=list(set(val_dataset.column_names) - _RETAIN_COLUMNS),
batched=True,
batch_size=1,
num_proc=4,
)
def train_dataloader(self):
return DataLoader(
self.train_dataset,
batch_size=self.args.batch_size,
collate_fn=self.collator,
)
def val_dataloader(self):
return DataLoader(
self.val_dataset,
batch_size=self.args.batch_size,
collate_fn=self.collator,
)
args = parse_args()
pl.seed_everything(args.seed)
os.makedirs(args.output_dir, exist_ok=True)
if args.strategy == "fsdp":
strategy = FSDPStrategy(
auto_wrap_policy=layers,  # NOTE: layers (the set of module classes to wrap) is not defined in this snippet
sharding_strategy="FULL_SHARD",
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
sync_module_states=True,
activation_checkpointing_policy=layers,
mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16),
forward_prefetch=True,
)
precision = None
elif args.strategy == "deepspeed":
strategy = DeepSpeedStrategy(stage=3)
precision = "bf16-mixed"
elif args.strategy == "ddp":
strategy = "ddp"
precision = "bf16-true"
else:
strategy = "auto"
precision = "bf16-true"
# This only works if you have a snapshot to work from.
trainer = pl.Trainer(
accelerator="cuda",
strategy=strategy,
devices=torch.cuda.device_count() if args.num_gpu is None else args.num_gpu,
default_root_dir=args.output_dir,
log_every_n_steps=1,
max_epochs=1,
precision=precision,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(args.model, padding_side="left", truncation_side="left")
tokenizer.pad_token = tokenizer.eos_token
data_module = DataModule(
tokenizer=tokenizer,
args=args,
)
model = LanguageModel(args=args, tokenizer=tokenizer)
trainer.fit(model, datamodule=data_module)
[ ]: