Tags: SFT · Supervised Fine-Tuning · Chat Templates · Response Masking · Catastrophic Forgetting · Fine-Tuning

SFT: What the Model Is Actually Predicting (and the Mask That Decides If It Works)

Published June 13, 2026 · 6 min read

Take the simplest possible training example:

{"instruction": "What is the capital of France?", "response": "Paris"}

Most people describe what happens next as "the model learns the answer is Paris." That sentence is the line between someone who has used SFT and someone who understands it — because mathematically, that is not what happens.

SFT is pretraining with curated data

Supervised fine-tuning uses the exact same objective as pretraining: next-token prediction. Nothing about the loss is new. The model minimizes cross-entropy over a sequence of tokens, predicting each token from the ones before it. "Learning Paris" is an emergent side-effect of getting good at predicting the token or tokens that spell "Paris" in that context (a real tokenizer emits subword pieces, not individual letters).
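To make the objective concrete, here is a minimal sketch of next-token cross-entropy in plain Python. The four-token vocabulary and the uniform "model" are stand-ins invented for illustration; a real model would return a learned distribution at each position.

```python
import math

def next_token_loss(token_ids, predict_probs):
    """Mean cross-entropy of predicting token t+1 from tokens 0..t.

    predict_probs(prefix) is a stand-in for the model: it returns a dict
    mapping each candidate next-token id to its probability.
    """
    losses = []
    for t in range(len(token_ids) - 1):
        prefix = token_ids[: t + 1]
        target = token_ids[t + 1]
        p = predict_probs(prefix)[target]
        losses.append(-math.log(p))
    return sum(losses) / len(losses)

# Hypothetical 4-token vocabulary and a "model" that is uniform over it.
VOCAB = [0, 1, 2, 3]

def uniform_model(prefix):
    return {tok: 1.0 / len(VOCAB) for tok in VOCAB}

loss = next_token_loss([0, 2, 1, 3], uniform_model)
print(round(loss, 4))  # log(4), i.e. 1.3863
```

A model that knows nothing scores log(vocabulary size) per token; training, whether pretraining or SFT, is the process of pushing that number down on the tokens that carry loss.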

What changes from pretraining isn't the loss. It's three things around it: the data (curated instruction/response pairs instead of raw web text), the masking (where loss is applied), and the template (the special-token structure). Get those right and a base "document completer" turns into an "assistant." Get them wrong and the loss curve looks beautiful while the model rambles.

Think of it as a student copying from a worked answer key. The teacher shows the question and the ideal answer; the student practices reproducing the answer token by token. Two subtleties in that picture are exactly where SFT succeeds or fails.

Subtlety one: response masking

Here is one training row after the chat template and tokenization, split into its two parts.

[Figure: Response masking in one SFT training row]

The prompt tokens (<user> What is the capital ... <assistant>) are labelled -100. That value is the default ignore_index of PyTorch's cross-entropy loss, so those positions contribute nothing to the loss. Loss is computed only on the response tokens (Paris <end>).

Why mask the prompt? You don't want the model spending its capacity learning to generate the question. You want it to learn the answer, given the question. Grade the student on what they write, not on re-copying the prompt back. TRL has shipped this as DataCollatorForCompletionOnlyLM, and Unsloth as train_on_responses_only. Forgetting it is a silent quality drop: the model wastes gradient learning to echo prompts.
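The masking itself is only a few lines. The sketch below uses plain Python and made-up token ids to show how the labels are built and how the loss skips every position marked -100. It leaves out the one-position shift between inputs and labels that real training code applies, and the log-probabilities are invented numbers, not model output.

```python
import math

IGNORE_INDEX = -100  # PyTorch's default ignore_index for cross-entropy

def build_labels(prompt_ids, response_ids):
    """Concatenate prompt and response; mask the prompt in the labels."""
    input_ids = prompt_ids + response_ids
    labels = [IGNORE_INDEX] * len(prompt_ids) + list(response_ids)
    return input_ids, labels

def masked_loss(token_log_probs, labels):
    """Average negative log-likelihood over positions whose label is kept.

    token_log_probs[i] is the log-probability the model gave to labels[i];
    entries at masked positions are never read.
    """
    kept = [i for i, y in enumerate(labels) if y != IGNORE_INDEX]
    return -sum(token_log_probs[i] for i in kept) / len(kept)

# Hypothetical token ids: 5 prompt tokens, then "Paris" and an end token.
input_ids, labels = build_labels([11, 12, 13, 14, 15], [42, 2])
print(labels)  # [-100, -100, -100, -100, -100, 42, 2]

# Invented log-probs: the model is terrible on the prompt, decent on the answer.
log_probs = [math.log(0.01)] * 5 + [math.log(0.5), math.log(0.25)]
print(round(masked_loss(log_probs, labels), 4))  # 1.0397
```

The poor prompt predictions never reach the loss, which is the point: only the two response positions are graded.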

This also answers a common interview question — "if SFT is just next-token prediction like pretraining, why does it change behavior so much?" Because behavior is shaped by the data distribution and the masking, not the loss. Same objective, but now every gradient pushes toward "after an assistant header, given an instruction, produce a helpful response," and the mask spends that gradient only on the response.

Subtlety two: the chat template

The second detail is the one I've seen waste the most time, because it fails silently.

[Figure: Chat template mismatch, the number one silent SFT failure]

A model learns that the answer comes after a specific special-token structure: <|im_start|>assistant in ChatML, or whatever its own template uses. If you train with one template and then infer with a different one (or none), the conditioning is wrong. The model still produces output. No error is thrown. The quality just collapses.

The rule: always use the tokenizer's own apply_chat_template on both sides — training and inference — so the exact token structure matches. Never hand-roll the format string. When someone says "train loss dropped beautifully but the model rambles at inference," template mismatch is the first thing I check.
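One cheap guard is to assert that the prompt you send at inference is an exact prefix of the text you trained on. The sketch below uses a toy ChatML-style renderer so it runs without downloading a tokenizer. Both render_chatml and hand_rolled are hypothetical helpers written for this example; in real code the renderer would be the tokenizer's apply_chat_template called with tokenize=False.

```python
def render_chatml(messages, add_generation_prompt=False):
    """Toy ChatML-style renderer, for illustration only.

    In real code this role is played by
    tokenizer.apply_chat_template(messages, tokenize=False,
                                  add_generation_prompt=...).
    """
    text = "".join(
        f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>\n" for m in messages
    )
    if add_generation_prompt:
        text += "<|im_start|>assistant\n"
    return text

def prompt_matches_training(render, messages):
    """True if the inference prompt is an exact prefix of the training text."""
    *history, answer = messages
    if answer["role"] != "assistant":
        raise ValueError("last message must be the assistant turn")
    train_text = render(messages)
    infer_prompt = render(history, add_generation_prompt=True)
    return train_text.startswith(infer_prompt)

def hand_rolled(messages, add_generation_prompt=False):
    """A mismatched format someone might type by hand at inference time."""
    if add_generation_prompt:
        return "User: " + messages[-1]["content"] + "\nAssistant: "
    return render_chatml(messages)

chat = [
    {"role": "user", "content": "What is the capital of France?"},
    {"role": "assistant", "content": "Paris"},
]

print(prompt_matches_training(render_chatml, chat))  # True
print(prompt_matches_training(hand_rolled, chat))    # False
```

The second case is the silent failure: both strings look reasonable to a human, and nothing raises until you compare them.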

The loop itself

Stepping back, the whole thing is one loop. Only the masked-loss box is special to SFT; everything else is the universal training loop from the pillar post.

[Figure: The SFT training loop]

Take a batch of prompt/answer pairs, apply the chat template and tokenize, run the forward pass to get logits, compute cross-entropy on the response tokens only, backpropagate, let AdamW update the weights, and repeat for 1–3 epochs. The one curve to watch is validation loss against training loss: when train keeps dropping but validation turns up, you're overfitting. On a narrow dataset that usually travels with catastrophic forgetting, where the model gets better at your task and worse at everything else. An in-domain validation set won't show the second half of that, so spot-check general prompts as well.
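Spotting the turn in the validation curve can be automated. The helper below is a small sketch, not part of any library: overfit_onset and its patience parameter are names chosen here, and the loss values are made up to show the shape of the problem.

```python
def overfit_onset(val_losses, patience=1):
    """Index of the best validation loss once it has failed to improve
    for `patience` consecutive evaluations, else None."""
    best_idx, best = 0, val_losses[0]
    for i, loss in enumerate(val_losses[1:], start=1):
        if loss < best:
            best_idx, best = i, loss
        elif i - best_idx >= patience:
            return best_idx
    return None

# Made-up curves for illustration: train keeps falling, validation turns up.
train = [2.10, 1.60, 1.20, 0.90, 0.65, 0.45]
val = [2.15, 1.70, 1.45, 1.40, 1.48, 1.60]

still_falling = all(b < a for a, b in zip(train, train[1:]))
print(still_falling)                   # True
print(overfit_onset(val, patience=2))  # 3
```

Here the checkpoint at evaluation 3 is the one to keep, even though training loss improves for two more evaluations.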

The knobs that matter

A few hyperparameters carry most of the outcome:

| Knob | Typical | Trade-off |
|---|---|---|
| learning rate | ~2e-4 (LoRA) · 1e-5–2e-5 (full FT) | too high → loss spikes / repetition; too low → underfits |
| epochs | 1–3 | more → better fit but ↑ catastrophic forgetting |
| max sequence length | match your data | longer → much more memory (attention is quadratic) |
| response masking | on | off → wastes capacity learning the prompt |
| warmup ratio | 5–10% of steps | too little → early instability |
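To see what a warmup ratio does to the learning rate, here is a sketch of a schedule with linear warmup followed by linear decay. The decay shape is an assumption made for this example (the table only specifies the warmup), and lr_at is a hypothetical helper, not a library function.

```python
def lr_at(step, total_steps, peak_lr, warmup_ratio=0.05):
    """Linear warmup to peak_lr, then linear decay to zero (assumed shape)."""
    warmup_steps = max(1, int(total_steps * warmup_ratio))
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    remaining = max(1, total_steps - warmup_steps)
    return peak_lr * max(0.0, (total_steps - step) / remaining)

# LoRA-style peak of 2e-4 over 1000 steps with 5% warmup (50 steps).
for step in (0, 24, 49, 500, 999):
    print(step, lr_at(step, total_steps=1000, peak_lr=2e-4))
```

With 5% warmup the first update runs at one fiftieth of the peak rate, which is what protects the early steps from the instability the table mentions.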

The production failures to name out loud: catastrophic forgetting (mitigate with LoRA, which freezes the base weights, plus a lower LR, fewer epochs, and some general data mixed in); "loss great, model bad" (overfit phrasing, template mismatch, or you trained on prompt tokens, so read actual generations, not the curve); and template mismatch as above.

How this shaped a real fine-tune

When I built the SFT dataset for a tool-using assistant on Gemma 3 4B, the data was multi-turn conversations, not single Q&A rows — which makes masking more interesting: across a conversation, every assistant turn carries loss and every user/tool turn is masked to -100. Get that wrong and the model learns to generate the user's side of the conversation, which is exactly the failure mode you don't want in an agent. I used Gemma's own chat template via apply_chat_template end to end, kept epochs low to protect the base model's general ability, and read real multi-turn generations rather than trusting the loss curve. The next posts cover the LoRA/QLoRA setup underneath this and the full agent case study.
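The multi-turn rule can be sketched in a few lines. This is an illustration with made-up token ids, not the code used for that fine-tune; how the role-header tokens of each turn are treated depends on the template and is left out here.

```python
IGNORE_INDEX = -100

def mask_multi_turn(turns):
    """turns: list of (role, token_ids). Returns (input_ids, labels) where
    only assistant tokens keep their ids as labels."""
    input_ids, labels = [], []
    for role, ids in turns:
        input_ids.extend(ids)
        if role == "assistant":
            labels.extend(ids)
        else:
            labels.extend([IGNORE_INDEX] * len(ids))
    return input_ids, labels

# Hypothetical pre-tokenized conversation with a tool call in the middle.
conversation = [
    ("user", [101, 102, 103]),
    ("assistant", [201, 202]),       # e.g. the tool call
    ("tool", [301, 302, 303]),
    ("assistant", [203, 204, 205]),  # final answer
]

input_ids, labels = mask_multi_turn(conversation)
print(labels)
# [-100, -100, -100, 201, 202, -100, -100, -100, 203, 204, 205]
```

Printing the labels for a handful of real rows and checking that only assistant spans survive is a quick way to catch the "model learns the user's side" bug before any training run.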

The takeaway

  • SFT is next-token prediction. Same objective as pretraining; what changes is the data, the masking, and the template.
  • Mask the prompt. Loss on response tokens only — otherwise you waste capacity learning to echo questions. In multi-turn data, mask every non-assistant turn.
  • Match the chat template exactly. Train and infer with the tokenizer's own apply_chat_template. Mismatch fails silently.
  • Curves lie, read generations. Great train loss with bad output means overfit, masking bug, or template mismatch.

Next in the series: LoRA & QLoRA — how to run this exact loop on a model far too big for your GPU, by freezing the base and training a thin adapter.