For LLM research intern preparation · Public edition · No proprietary results included
Part 1 — Concepts & Key Formulas
1.1 Causal Language Modeling (CLM)
Core Idea: Predict the next token autoregressively; during training, a causal mask prevents future information leakage.
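Loss Function:
$$\mathcal{L}_{\text{CLM}}(\theta) = -\sum_{t=1}^{T} \log P_\theta(x_t \mid x_{<t})$$
Derivation:
- By the chain rule: $P_\theta(x_1, \dots, x_T) = \prod_{t=1}^{T} P_\theta(x_t \mid x_{<t})$
- Taking the logarithm and negating gives the cross-entropy loss above
- In practice, logits have shape (batch, seq_len, vocab_size); targets are the input token ids shifted left by one position, so the logits at position $t$ are scored against token $t+1$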
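The shift between logits and targets is easy to get wrong, so here is a minimal pure-Python sketch of the loss for a single sequence. The helper names `log_softmax` and `causal_lm_loss` are illustrative, not from any library; a real implementation would use a tensor framework and batch the computation.

```python
import math

def log_softmax(row):
    """Numerically stable log-softmax over one vocabulary-sized row."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

def causal_lm_loss(logits, token_ids):
    """Mean next-token cross-entropy for one sequence.

    logits:    seq_len rows of vocab_size scores (row t scores token t+1)
    token_ids: seq_len token ids
    """
    # Shift: rows 0..T-2 predict tokens 1..T-1; the last row has no target.
    pred_rows, targets = logits[:-1], token_ids[1:]
    nll = [-log_softmax(row)[tgt] for row, tgt in zip(pred_rows, targets)]
    return sum(nll) / len(nll)

# Uniform logits over a 4-token vocabulary give a loss of log(4).
uniform_logits = [[0.0] * 4 for _ in range(3)]
loss = causal_lm_loss(uniform_logits, [2, 0, 3])
print(round(loss, 4))  # 1.3863
```

A loss of $\log(\text{vocab\_size})$ at initialization is a common sanity check: it is what a model with no preference between tokens produces.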
1.2 Softmax & Attention
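Scaled Dot-Product Attention:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V$$
Why scale by $\sqrt{d_k}$?
- Assume elements of $q$ and $k$ are i.i.d. with mean 0 and variance 1
- Then each element of $QK^\top$, a dot product $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$, has variance $d_k$
- When $d_k$ is large, softmax inputs are large → vanishing gradients (softmax saturation)
- Dividing by $\sqrt{d_k}$ normalizes the variance to 1, keeping gradients stable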
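The variance argument can be checked empirically. The sketch below draws random unit-variance vectors and compares the variance of raw and scaled dot products; the function name, sample count, and seed are arbitrary choices for illustration.

```python
import random
import statistics

def dot_product_variance(d_k, n_samples=2000, seed=0):
    """Empirical variance of q.k before and after dividing by sqrt(d_k)."""
    rng = random.Random(seed)
    raw = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        raw.append(sum(a * b for a, b in zip(q, k)))
    scaled = [s / d_k ** 0.5 for s in raw]
    return statistics.pvariance(raw), statistics.pvariance(scaled)

var_raw, var_scaled = dot_product_variance(d_k=64)
print(f"raw variance    ~ {var_raw:.1f}  (theory: 64)")
print(f"scaled variance ~ {var_scaled:.2f} (theory: 1)")
```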
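Multi-Head Attention (MHA):
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\,W^O, \qquad \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)$$
where $W_i^Q, W_i^K \in \mathbb{R}^{d_{\text{model}} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{\text{model}} \times d_v}$, $W^O \in \mathbb{R}^{h d_v \times d_{\text{model}}}$.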
GQA / MQA Variants:
- Multi-Query Attention (MQA): All heads share the same $K$ and $V$; only $Q$ differs per head → KV cache significantly reduced
- Grouped-Query Attention (GQA): The query heads are divided into $g$ groups, each sharing one $K$/$V$ head; a middle ground between MHA ($g$ = number of heads) and MQA ($g = 1$)
1.3 Position Encoding
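Rotary Position Embedding (RoPE):
$$\begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}, \qquad i = 0, \dots, d/2 - 1$$
where $m$ is the token position and $\theta_i = 10000^{-2i/d}$.
Properties:
- The inner product $\langle R_m q, R_n k \rangle$ depends only on the relative position $m - n$ → naturally encodes relative positions
- No learnable parameters (deterministic)
- Better extrapolation than learned positional embeddings (length can be extended with NTK-aware scaling)
Practical Implementation of RoPE:
$$\text{RoPE}(x) = x \odot \cos(m\theta) + \text{rotate\_half}(x) \odot \sin(m\theta), \qquad \text{rotate\_half}([x_1, x_2]) = [-x_2, x_1]$$
For $x = [x_1, x_2]$ split into two halves of size $d/2$, the rotate_half implementation pairs the two halves $(x_i, x_{i+d/2})$ for the 2D rotation (the original RoPE paper uses adjacent pairs $(x_{2i}, x_{2i+1})$; the two differ only by a dimension permutation and are mathematically equivalent).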
1.4 LoRA — Low-Rank Adaptation
Motivation: Full fine-tuning of large models has heavy GPU memory overhead (requires storing parameters, gradients, and optimizer states). LoRA freezes the pretrained weights and trains only a low-rank delta.
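Key Formula:
$$h = W_0 x + \Delta W x = W_0 x + BAx$$
where $W_0 \in \mathbb{R}^{d \times k}$ is frozen, $B \in \mathbb{R}^{d \times r}$, $A \in \mathbb{R}^{r \times k}$, $r \ll \min(d, k)$.
Scaling:
$$h = W_0 x + \frac{\alpha}{r} BAx$$
$\alpha$ is a scaling hyperparameter, typically set to $r$ or $2r$.
Parameter Count Analysis:
- Original parameters: $d \times k$
- LoRA parameters: $r \times (d + k)$
- Example: $d = k = 4096$, $r = 16$ → LoRA params $= 16 \times 8192 = 131{,}072$, which is $\approx 0.78\%$ of the original $16.8$M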
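The parameter-count formulas above are simple enough to check with a few lines of arithmetic. This is a sketch; `lora_param_counts` is an illustrative helper and the shapes are example values, not tied to a specific model.

```python
def lora_param_counts(d, k, r):
    """Return (full, lora, ratio) for a d x k weight with rank-r LoRA factors."""
    full = d * k            # frozen W0
    lora = r * (d + k)      # B is d x r, A is r x k
    return full, lora, lora / full

full, lora, ratio = lora_param_counts(d=4096, k=4096, r=16)
print(full, lora, f"{ratio:.2%}")  # 16777216 131072 0.78%
```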
Initialization:
- $A$: Kaiming uniform initialization (or Gaussian)
- $B$: zero initialization → at training start $\Delta W = BA = 0$, preserving pretrained output
Merge for Inference:
$$W_{\text{merged}} = W_0 + \frac{\alpha}{r} BA$$
After merging, inference has no extra overhead.
1.5 Reinforcement Learning from Human Feedback
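Reward Model Training (Bradley-Terry):
$$\mathcal{L}_{\text{RM}}(\phi) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\big(r_\phi(x, y_w) - r_\phi(x, y_l)\big)\right]$$
where $(y_w, y_l)$ is a human-annotated preference pair (preferred vs. rejected).
PPO Objective:
$$\max_{\pi_\theta}\; \mathbb{E}_{x \sim \mathcal{D},\, y \sim \pi_\theta(\cdot \mid x)}\big[r_\phi(x, y)\big] - \beta\, D_{\text{KL}}\big(\pi_\theta(\cdot \mid x)\,\|\,\pi_{\text{ref}}(\cdot \mid x)\big)$$
Role of KL Divergence:
- $\beta$ too small → reward hacking (the policy exploits gaps in the reward model)
- $\beta$ too large → policy barely moves (degenerates to the SFT model)
DPO (Direct Preference Optimization):
Bypasses an explicit reward model; derived from the Bradley-Terry model:
$$\mathcal{L}_{\text{DPO}}(\theta) = -\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}\left[\log \sigma\left(\beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)\right]$$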
DPO Advantages:
- No RL sampling loop (no need to generate responses during training)
- No explicit reward model required
- More stable training, fewer hyperparameters
DPO Limitations:
- The implicit reward may generalize less well than an explicit RM
- More sensitive to preference data quality (no RM "buffer")
- Not easily extended to online RL (requires on-policy sampling for improvement)
1.5b Distributed RLHF Architecture
GPU Utilization Problem in Naive Co-located PPO
The simplest implementation runs the actor, reference model, critic, and reward model all on the same set of GPUs (co-located). The bottleneck is the rollout phase:
┌─────────────────────────────────────────────────────┐
│ Co-located PPO (simplified timeline) │
│ │
│ ──[rollout: actor autoregressive gen]──► ──[train: PPO update]──► │
│ GPU busy with inference trainer busy, actor idle │
└─────────────────────────────────────────────────────┘
- During rollout: the actor generates tokens one at a time (memory-bound; throughput limited by HBM bandwidth); GPU MFU (Model FLOP Utilization) is often low; the trainer (ZeRO/FSDP) sits idle.
- During training: forward + backward is compute-intensive; the rollout worker sits idle.
- Result: the two phases alternate, and overall GPU utilization is a weighted average of each phase's utilization — well below peak utilization for pure training or pure inference.
⚠️ These are not precise measurements; actual MFU depends on model size, batch size, and hardware. The description above captures the qualitative problem; for actual numbers consult the technical reports of the relevant frameworks (OpenRLHF, veRL, etc.).
Disaggregated Rollout + Train Topology
To address the above, disaggregate rollout workers from train workers:
┌──────────────────────────────────────────────────────────────────┐
│ Disaggregated PPO topology │
│ │
│ ┌─────────────────────────┐ ┌──────────────────────────┐ │
│ │ Rollout Workers │ │ Train Workers │ │
│ │ (vLLM / SGLang engine) │ │ (ZeRO-3 / FSDP) │ │
│ │ │ │ │ │
│ │ actor (inference mode) │─────►│ actor (grad update) │ │
│ │ ref model (frozen) │ │ critic (grad update) │ │
│ │ reward model (frozen) │ │ │ │
│ └─────────────────────────┘ └──────────────────────────┘ │
│ │ generate responses + rewards ▲ │
│ │ (rollout buffer) │ weight sync │
│ └────────────────────────────────────────┘ │
│ sync actor weights every N steps (or each rollout) │
└──────────────────────────────────────────────────────────────────┘
Key design points:
- Rollout workers load actor inference weights (FP16/BF16) and use vLLM or SGLang for continuous-batching autoregressive generation.
- Train workers hold the full trainable parameters (including optimizer state) via ZeRO-3 or FSDP, and execute PPO/GRPO gradient updates.
- Weight sync: after train workers complete a batch, they broadcast the latest actor weights to rollout workers. Sync frequency is typically once per PPO iteration (one full rollout + train cycle); some implementations support finer-grained step-by-step sync.
- Ref model / RM: generally reside on the rollout side in inference mode (frozen weights, no gradients), saving memory on the train side.
4-Model Memory Breakdown + How LoRA-in-RL Saves Memory
Standard RLHF involves four models:
| Model | Parameters | Gradients | Optimizer state (AdamW) | Typical location |
|---|---|---|---|---|
| Actor | ✅ (trainable) | ✅ | ✅ ($m$, $v$ in FP32 ≈ 8 bytes/param) | Train workers |
| Ref model | ✅ (frozen) | ✗ | ✗ | Rollout workers or separate node |
| Critic | ✅ (trainable) | ✅ | ✅ | Train workers (can share GPU with actor) |
| Reward Model | ✅ (frozen) | ✗ | ✗ | Rollout workers |
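Single-model memory estimate (using a 7B model as example; order of magnitude, not exact):
$$M_{\text{trainable}} \approx \underbrace{2\Phi}_{\text{params (BF16)}} + \underbrace{2\Phi}_{\text{grads (BF16)}} + \underbrace{8\Phi}_{\text{Adam } m,\, v \text{ (FP32)}} = 12\Phi \approx 84\ \text{GB}, \qquad M_{\text{frozen}} \approx 2\Phi \approx 14\ \text{GB} \quad (\Phi = 7\text{B})$$
where $2\Phi$ is parameter memory (BF16) and the optimizer term here counts only the FP32 $m$ + $v$ momenta (8 bytes/param total; the FP32 master copy adds $4\Phi$ → 12 bytes/param for the full optimizer state, cf. §1.6). Naively co-locating all 4 models (two trainable, two frozen: roughly $2 \times 84 + 2 \times 14 \approx 196$ GB before activations and KV cache) puts memory requirements in the hundreds-of-GB range. A 7B model can still fit on a single 8×80 GB machine, but naive co-location yields low GPU utilization (see above); larger models (e.g., 70B) far exceed a single node.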
Memory savings with LoRA-in-RL:
- Only the LoRA adapters of the actor and critic are trained ($A$, $B$); pretrained weights are frozen.
- Gradients and optimizer states scale only with the LoRA parameter count. At ~99% reduction in trainable parameters (e.g., rank=16), optimizer state drops from ~56 GB to ~1 GB (order-of-magnitude estimate).
- Trade-off: LoRA's expressiveness is limited by its rank; policy update magnitude during RL may be constrained. In practice, PPO + LoRA has been validated in several public works (exact results depend on task and rank; consult original papers).
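To make the four-model estimate concrete, the sketch below adds up weights, gradients, and Adam moments under the same assumptions as the text (BF16 weights and gradients, 8 bytes/param for FP32 $m$ + $v$). It ignores activations, KV cache, the FP32 master copy, and the small adapter weights themselves; the function names and the 1% trainable fraction are illustrative assumptions.

```python
GB = 1e9

def model_memory_gb(n_params, trainable, trainable_fraction=1.0,
                    bytes_weights=2, bytes_grads=2, bytes_adam_moments=8):
    """Weights, plus gradients and Adam m/v for the trained subset if trainable."""
    mem = n_params * bytes_weights
    if trainable:
        trained = n_params * trainable_fraction
        mem += trained * (bytes_grads + bytes_adam_moments)
    return mem / GB

def rlhf_total_gb(n_params, trainable_fraction=1.0):
    """Actor and critic are trainable; reference and reward models are frozen."""
    roles = {"actor": True, "critic": True, "ref": False, "reward": False}
    return sum(model_memory_gb(n_params, t, trainable_fraction)
               for t in roles.values())

full = rlhf_total_gb(7e9)                            # full fine-tuning
lora = rlhf_total_gb(7e9, trainable_fraction=0.01)   # ~1% trainable (LoRA)
print(f"full: {full:.0f} GB, LoRA: {lora:.1f} GB")
```

Under these assumptions the frozen BF16 weights become the dominant term once LoRA is used, which is why quantizing or sharing the frozen base is the usual next step.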
Async vs Sync Rollout — Staleness
| Mode | Description | Advantages | Disadvantages |
|---|---|---|---|
| Sync rollout | Training begins only after rollout completes; the next rollout begins only after training completes | No staleness, on-policy | Low GPU utilization (two phases alternate idle) |
| Async rollout | Rollout workers continuously generate; train workers continuously update; weight sync is delayed | High GPU utilization, high throughput | Staleness: rollout uses weights from $k$ steps ago; data is off-policy |
Impact of staleness:
- The divergence grows between the policy $\pi_{\theta_{\text{old}}}$ used during generation and the target $\pi_\theta$ being updated.
- PPO's clipped objective tolerates mild off-policy data (via the importance ratio $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{\text{old}}}(a_t \mid s_t)$), but when staleness is large, the variance of the importance ratio grows sharply.
- In practice, many frameworks opt for near-synchronous operation (syncing weights every $N$ steps), balancing throughput against staleness.
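A per-token sketch of the clipped surrogate shows how clipping bounds the effect of stale samples. The function name and the clip range `eps=0.2` are illustrative assumptions; real implementations operate on batched tensors and add value and entropy terms.

```python
import math

def clipped_surrogate(logp_new, logp_old, advantage, eps=0.2):
    """Per-token PPO clipped objective (to be maximized)."""
    ratio = math.exp(logp_new - logp_old)            # importance ratio r_t
    clipped = max(1.0 - eps, min(ratio, 1.0 + eps))
    return min(ratio * advantage, clipped * advantage)

# Fresh data: ratio is 1.1, inside the clip range, so the objective is ratio * A.
fresh = clipped_surrogate(math.log(0.55), math.log(0.50), advantage=1.0)
# Stale data: ratio is 3.0, but the gain is capped at (1 + eps) * A.
stale = clipped_surrogate(math.log(0.60), math.log(0.20), advantage=1.0)
print(round(fresh, 3), round(stale, 3))  # 1.1 1.2
```

Once the ratio leaves the clip range in the direction the advantage favors, the gradient for that token is zero, so very stale batches contribute little signal even though they cost rollout compute.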
Reference Implementations: OpenRLHF vs veRL
| Dimension | OpenRLHF | veRL |
|---|---|---|
| Focus | Research-friendly, clean, quick to get started | Production-scale, more aggressive performance optimization |
| Rollout engine | vLLM (deeply integrated) | Supports both vLLM and SGLang |
| Training parallelism | DeepSpeed ZeRO-3 | Supports both FSDP and Megatron-LM TP/PP |
| 4-model scheduling | Supports co-located and disaggregated modes | Hybrid Engine (rollout/train share GPUs with dynamic switching) |
| LoRA-in-RL | ✅ | ✅ |
| Code size | Smaller; clean architecture; good for custom extensions | Larger, but production-complete (checkpoint, fault tolerance) |
| Typical use case | Academic experiments, quick algorithm validation | Large-scale post-training pipelines |
✅ Both are public implementations and can serve as reference skeletons for system design questions. Consult official technical reports and GitHub for specific performance numbers — they vary significantly across versions and hardware. In interviews, cite "order-of-magnitude" rather than exact values.
Throughput Estimation: Rollout vs Train GPU-hours Ratio
⚠️ The following is a qualitative order-of-magnitude analysis. Actual numbers are highly sensitive to model scale, response length, and hardware configuration. In interviews, explicitly say "rough estimate" rather than citing precise benchmarks.
Reasoning framework (using a 7B actor as example):
- Rollout cost: Autoregressive generation is memory-bound; each generated token still requires a full forward pass through all layers (KV cache means only 1 new token is processed per step, but the number of layers is unchanged); throughput is limited by HBM bandwidth. With average response length $L$, rollout cost is roughly proportional to $L \times$ (memory access volume per step, dominated by reading the $\Phi$ weights).
- Train cost: Forward + backward ≈ $6\Phi T$ FLOPs ($T$ = sequence length, $\Phi$ = parameter count; forward ≈ $2\Phi$, backward ≈ $4\Phi$, total $6\Phi$ per token).
- Typical conclusion (order of magnitude): when responses are long (hundreds of tokens), rollout GPU-hours are often comparable to or even greater than train GPU-hours — this is one of the core motivations for disaggregated architectures. If rollout is much faster than training, disaggregation adds limited benefit; if rollout is the bottleneck, allocating more rollout workers is the natural scaling approach.
1.6 Distributed Training Parallelism
Data Parallelism (DP)
Each GPU holds a complete model replica; gradients are synchronized via All-Reduce.
Communication: All-Reduce of parameter gradients per step = $2\Phi$ per GPU (ring all-reduce: reduce-scatter $\Phi$ + all-gather $\Phi$, where $\Phi$ = #params).
ZeRO (Zero Redundancy Optimizer)
| Stage | Sharded content | Memory per GPU |
|---|---|---|
| ZeRO-1 | Optimizer states (Adam: master + $m$ + $v$) | ~4× parameter count in bytes for large $N$ (same parameter memory as DP) |
| ZeRO-2 | + Gradients | ~2× parameter count in bytes for large $N$ |
| ZeRO-3 | + Parameters | ~$16/N$ × parameter count in bytes ($N$ = number of GPUs) |
Overhead: ZeRO-3 requires All-Gather of parameters during both forward and backward passes, increasing communication volume (see the note below).
The memory breakdown (mixed-precision Adam, $\Phi$ = #params):
| Component | Precision | Bytes/param | Memory |
|---|---|---|---|
| Model params | fp16 | 2 | $2\Phi$ |
| Gradients | fp16 | 2 | $2\Phi$ |
| Adam optimizer states | fp32 | 12 | $12\Phi$ |
The optimizer states = fp32 master-weight copy ($4\Phi$) + first moment ($4\Phi$) + second moment ($4\Phi$), totaling $12\Phi$; all model states together take $16\Phi$ (a 7.5B model → 120 GB, too big for one GPU). Per-GPU memory across ZeRO stages on $N$ GPUs:
| Stage | Sharded | Per-GPU memory | Example ($\Phi$ = 7.5B, $N$ = 64) |
|---|---|---|---|
| baseline (DP) | none | $(2 + 2 + 12)\Phi = 16\Phi$ | 120 GB |
| ZeRO-1 | optimizer states | $4\Phi + 12\Phi/N$ | 31.4 GB |
| ZeRO-2 | + gradients | $2\Phi + 14\Phi/N$ | 16.6 GB |
| ZeRO-3 | + parameters | $16\Phi/N$ | 1.9 GB |
ZeRO-3 shards all three; communication is ~1.5× of plain DP (forward all-gather params, backward all-gather params + reduce-scatter grads) — trading communication for memory. Source: Rajbhandari et al. 2020, arXiv:1910.02054.
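The per-stage memory formulas can be turned into a small calculator. This sketch covers model states only (no activations or buffers) under mixed-precision Adam with 12 bytes/param of optimizer state; `zero_memory_gb` is an illustrative name, and the 7.5B / 64-GPU setting is the example from the ZeRO paper.

```python
def zero_memory_gb(n_params, n_gpus, stage, k_opt=12):
    """Per-GPU model-state memory (GB) under mixed-precision Adam."""
    params, grads, opt = 2 * n_params, 2 * n_params, k_opt * n_params
    if stage >= 1:          # ZeRO-1: shard optimizer states
        opt /= n_gpus
    if stage >= 2:          # ZeRO-2: also shard gradients
        grads /= n_gpus
    if stage >= 3:          # ZeRO-3: also shard parameters
        params /= n_gpus
    return (params + grads + opt) / 1e9

for stage in range(4):
    print(stage, round(zero_memory_gb(7.5e9, 64, stage), 1))
# 0 120.0
# 1 31.4
# 2 16.6
# 3 1.9
```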
Tensor Parallelism (TP)
Each layer's weight matrix is split column-wise or row-wise across multiple GPUs.
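- Column-parallel: $Y = XW$; $W$ is split column-wise as $[W_1, W_2]$; each GPU computes $Y_i = XW_i$ without communication. If a row-parallel layer follows, the sharded output feeds it directly, so no gather is needed and the pair shares a single AllReduce.
- Row-parallel: $Y = XW = \sum_i X_i W_i$, with $W$ split row-wise and $X$ split column-wise to match; each GPU computes $X_i W_i$ independently, then one AllReduce sums the partial results.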
Megatron-LM design: Column-parallel Linear → GeLU (local) → Row-parallel Linear → AllReduce. The entire MLP block requires only one AllReduce (plus one in the backward pass).
Pipeline Parallelism (PP)
The model is split into layer segments assigned to different machines.
- GPipe strategy: Split the mini-batch into $m$ micro-batches; process all forward passes sequentially, then all backward passes in reverse order.
- 1F1B schedule: Alternately execute 1 forward and 1 backward; bubble fraction is the same as GPipe, but per-stage peak activation memory drops from $m$ to at most $p$ micro-batch buffers (bounded by pipeline depth: backward starts earlier, activations freed sooner).
- Bubble rate: $\frac{p - 1}{m + p - 1}$, $p$ = pipeline stages, $m$ = micro-batches.
- Interleaved 1F1B (virtual stages): each device holds $v$ non-contiguous layer chunks, dropping the bubble to $\frac{p - 1}{vm + p - 1}$ (~$1/v$ of non-interleaved when $m \gg p$) at the cost of extra per-micro-batch p2p communication. This is Megatron-LM's interleaved (virtual-pipeline) schedule, used when virtual pipeline stages are enabled.
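A quick calculator makes the effect of more micro-batches versus interleaving visible. It assumes equal forward and backward time per micro-batch on every stage and measures the bubble as a fraction of total step time; `bubble_fraction` is an illustrative name.

```python
def bubble_fraction(p, m, v=1):
    """Idle fraction of a pipeline step.

    p: pipeline stages, m: micro-batches, v: layer chunks per device
    (v=1 is GPipe / plain 1F1B, v>1 is the interleaved schedule).
    """
    return (p - 1) / (v * m + p - 1)

print(round(bubble_fraction(p=4, m=8), 3))        # 0.273
print(round(bubble_fraction(p=4, m=32), 3))       # 0.086
print(round(bubble_fraction(p=4, m=8, v=2), 3))   # 0.158
```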
Sequence Parallelism (SP)
Operations like LayerNorm and Dropout that carry no parameters but occupy activation memory are split along the sequence dimension.
- Ring Attention: the long sequence of length $S$ is split into $N$ segments across GPUs; KV blocks are passed via ring communication, reducing per-GPU activation memory from $O(S)$ to $O(S/N)$.
Practical Guidance:
- Single-node 8 GPUs: DP/ZeRO-2 + TP (NVLink is fast)
- Multi-node: DP/ZeRO-3 + PP (low cross-node bandwidth) + TP (within node)
- Very long contexts: add SP (Ring Attention)
1.7 KV Cache Memory Analysis
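Each layer needs to cache $K$ and $V$ for every token:
$$M_{\text{KV}} = 2 \times L \times n_{kv} \times d_h \times S \times B \times \text{bytes\_per\_param}$$
- $L$ = number of layers, $n_{kv}$ = number of KV heads (fewer than Q heads with GQA), $d_h$ = dimension per head, $S$ = sequence length, $B$ = concurrent batch (number of requests); the leading 2 counts $K$ and $V$
- With FP16, bytes_per_param = 2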
PagedAttention (vLLM): KV cache is divided into fixed-size pages (e.g., 16 tokens/page), allocated on demand, eliminating memory fragmentation and supporting more concurrent requests.
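The KV cache size is a plain product of the model's shape parameters. The sketch below uses a hypothetical 32-layer configuration with head dimension 128 (illustrative values, not a specific model's published config) to compare full MHA against GQA with 8 KV heads.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch, bytes_per_elem=2):
    """Total KV cache size in bytes."""
    # Factor 2: one K tensor and one V tensor per layer.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

GiB = 1024 ** 3
mha = kv_cache_bytes(n_layers=32, n_kv_heads=32, head_dim=128, seq_len=4096, batch=1)
gqa = kv_cache_bytes(n_layers=32, n_kv_heads=8, head_dim=128, seq_len=4096, batch=1)
print(mha / GiB, gqa / GiB)  # 2.0 0.5
```

The cache scales linearly with both sequence length and batch size, so the per-request figure multiplies quickly under concurrent serving.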
1.8 Quantization Fundamentals
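Symmetric Quantization:
$$s = \frac{\max |x|}{2^{b-1} - 1}, \qquad x_q = \text{clamp}\!\left(\text{round}\!\left(\frac{x}{s}\right),\, -(2^{b-1} - 1),\, 2^{b-1} - 1\right), \qquad \hat{x} = s \cdot x_q$$
Asymmetric Quantization:
$$s = \frac{x_{\max} - x_{\min}}{2^{b} - 1}, \qquad z = \text{round}\!\left(-\frac{x_{\min}}{s}\right), \qquad x_q = \text{clamp}\!\left(\text{round}\!\left(\frac{x}{s}\right) + z,\, 0,\, 2^{b} - 1\right), \qquad \hat{x} = s \cdot (x_q - z)$$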
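A pure-Python sketch of asymmetric (zero-point) quantization on a small list of values. It uses per-tensor min/max calibration, the simplest option; the function names are illustrative and real kernels operate on tensors with per-channel or per-group scales.

```python
def asymmetric_quantize(values, n_bits=8):
    """Map floats to unsigned ints in [0, 2^n_bits - 1] with a zero point."""
    qmax = 2 ** n_bits - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / qmax or 1.0          # guard against a constant input
    zero_point = round(-lo / scale)
    q = [min(qmax, max(0, round(v / scale) + zero_point)) for v in values]
    return q, scale, zero_point

def asymmetric_dequantize(q, scale, zero_point):
    return [scale * (qi - zero_point) for qi in q]

x = [-0.5, 0.0, 0.25, 1.5]
q, scale, zp = asymmetric_quantize(x)
x_hat = asymmetric_dequantize(q, scale, zp)
max_err = max(abs(a - b) for a, b in zip(x, x_hat))
print(q, zp, max_err <= scale)  # [0, 64, 96, 255] 64 True
```

Because the range is skewed toward positive values, the zero point lands at 64 rather than mid-range, which is the case where asymmetric quantization uses the integer grid better than symmetric quantization.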
GPTQ — layer-wise post-training quantization via OBS (Frantar et al., ICLR 2023, arXiv:2210.17323):
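- Minimizes per-layer reconstruction error $\|WX - \hat{W}X\|_2^2$; follows OBS/OBQ using the inverse Hessian $H^{-1}$ (with $H = 2XX^\top$) to compensate.
- After quantizing weight $w_q$, the error is redistributed to the not-yet-quantized weights $F$ via $\delta_F = -\frac{w_q - \text{quant}(w_q)}{[H_F^{-1}]_{qq}} \cdot (H_F^{-1})_{:,q}$, canceling the output shift from quantization.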
- Engineering: fixed column order (drops OBQ's greedy per-weight selection) + Cholesky factorization for stability + block updates — quantizes 175B to 3–4 bit in hours.
AWQ — activation-aware weight quantization (Lin et al. 2023, arXiv:2306.00978):
- Observation: weights are not equally important; the ~0.1–1% "salient weights" are identified by activation magnitude (not weight magnitude).
- Method: per-channel scaling of salient channels: multiply weights by $s$ and divide the corresponding activations by $s$ ($Q(w \cdot s) \cdot (x / s)$, with $w x$ unchanged), shrinking the relative quant error on salient weights; grid-search $s$ per layer. Forward-only, no backprop.
SmoothQuant — migrate quantization difficulty from activations to weights (Xiao et al., ICML 2023, arXiv:2211.10438):
- Problem: activations have per-channel outliers that are very hard to quantize, while weights are smooth and easy.
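- Method: per-channel smoothing $Y = (X\,\text{diag}(s)^{-1}) \cdot (\text{diag}(s)\,W)$ (same orientation as AWQ above, preserving $XW$), with scale $s_j = \max|X_j|^{\alpha} / \max|W_j|^{1-\alpha}$ ($\alpha = 0.5$ is a common default), moving part of the activation dynamic range into the weights to enable W8A8.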
FP8 (Hopper/H100): E4M3 (4 exponent, 3 mantissa, range ±448) for forward weights/activations; E5M2 (5 exponent, 2 mantissa, larger range ±57344) for gradients. Vs asymmetric INT8: drops the zero-point calibration (still needs per-tensor amax scaling), and the float format is more robust to outliers.
KV-cache quantization: at long context the KV cache dominates memory. K has per-channel outliers → quantize per-channel; V is smoother → per-token (e.g., KIVI, arXiv:2402.02750). int8/int4/fp8 cut KV memory 2–4×; int8/fp8 with negligible loss on most tasks, while int4 is task-sensitive (long-context retrieval especially).
1.9 Speculative Decoding
Core Idea: A small draft model cheaply proposes $k$ tokens (generated autoregressively); the target model then scores all $k$ positions in parallel in a single forward pass and verifies them.
Accept-Reject Sampling:
- For position $i$: if target model probability $p(x_i)$ $\geq$ draft model probability $q(x_i)$ → accept
- If $p(x_i) < q(x_i)$, accept with probability $p(x_i) / q(x_i)$; otherwise reject and resample from the normalized residual $\text{norm}(\max(0,\, p - q))$; drafted tokens after the first rejection are discarded
- The output distribution is exactly identical to direct sampling from the target model (lossless)
Speedup: Depends on the token acceptance rate between the draft model and the target model. In typical scenarios a $2$–$3\times$ speedup is achievable.
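The losslessness claim can be checked on a toy vocabulary. The sketch below applies the accept/reject rule to a single drafted token and verifies by simulation that outputs follow the target distribution; the function name and the two three-token distributions are made up for illustration.

```python
import random

def speculative_step(p_target, q_draft, draft_token, rng):
    """Accept or replace one drafted token; returns (token, accepted)."""
    p, q = p_target[draft_token], q_draft[draft_token]
    if rng.random() < min(1.0, p / q):
        return draft_token, True
    # Rejected: resample from the normalized residual max(0, p - q).
    residual = [max(0.0, pt - qd) for pt, qd in zip(p_target, q_draft)]
    token = rng.choices(range(len(residual)), weights=residual)[0]
    return token, False

rng = random.Random(0)
p_target = [0.6, 0.3, 0.1]   # toy target distribution
q_draft = [0.2, 0.5, 0.3]    # toy draft distribution
counts = [0, 0, 0]
n = 20000
for _ in range(n):
    drafted = rng.choices(range(3), weights=q_draft)[0]
    token, _ = speculative_step(p_target, q_draft, drafted, rng)
    counts[token] += 1
print([round(c / n, 2) for c in counts])  # empirical frequencies, close to p_target
```

In this toy case the draft is accepted 60% of the time, and every rejection resamples token 0, the only token the draft under-weights; the two paths together reproduce the target probabilities exactly.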
1.10 7-Step ML System Design Framework
| Step | Name | Key points |
|---|---|---|
| 1 | Clarify | Data volume, model scale, QPS, latency SLA, memory budget, success metrics |
| 2 | Data | Sources, cleaning strategy, labeling approach (human / weak supervision / model-generated), data flywheel |
| 3 | Model | Architecture choice, parameter count, Pre-train vs Fine-tune vs RAG, PEFT vs full fine-tuning |
| 4 | Training Infra | Parallelism strategy (DP/TP/PP/SP), memory optimization, batch size, LR schedule |
| 5 | Evaluation | Offline benchmark + human evaluation + safety eval |
| 6 | Serving | Quantization, dynamic batching, KV cache management, latency vs throughput |
| 7 | Monitoring | Quality drift (PPL, accuracy), data distribution shift, safety incidents |
Part 2 — From-Scratch Snippets
The following are minimal educational implementations highlighting core logic, omitting production-level error handling and optimization.
2.1 Scaled Dot-Product Attention
```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(
    q: torch.Tensor,  # (batch, n_heads, seq_q, d_k)
    k: torch.Tensor,  # (batch, n_heads, seq_k, d_k)
    v: torch.Tensor,  # (batch, n_heads, seq_k, d_v)
    mask: torch.Tensor | None = None,  # (batch, 1, seq_q, seq_k) or broadcastable; 0 = masked
) -> tuple[torch.Tensor, torch.Tensor]:  # returns (output, attn_weights)
    d_k = q.size(-1)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    attn_weights = F.softmax(scores, dim=-1)
    return torch.matmul(attn_weights, v), attn_weights
```
2.2 Causal Self-Attention Layer
```python
import math

import torch
import torch.nn as nn


class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, C = x.shape
        qkv = self.qkv_proj(x).reshape(B, T, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, H, T, d_k)
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Causal mask: lower triangular
        mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        return self.out_proj(out)
```
2.3 LoRA Layer
```python
import math

import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Wraps a frozen nn.Linear and adds a trainable low-rank delta."""

    def __init__(self, base_linear: nn.Linear, rank: int = 16, alpha: float = 32):
        super().__init__()
        self.base = base_linear
        self.base.weight.requires_grad_(False)
        if self.base.bias is not None:
            self.base.bias.requires_grad_(False)
        in_features = base_linear.in_features
        out_features = base_linear.out_features
        self.lora_a = nn.Parameter(torch.empty(rank, in_features))
        self.lora_b = nn.Parameter(torch.zeros(out_features, rank))  # B init to 0
        nn.init.kaiming_uniform_(self.lora_a, a=math.sqrt(5))
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        base_out = self.base(x)
        lora_out = (x @ self.lora_a.T @ self.lora_b.T) * self.scaling
        return base_out + lora_out

    @torch.no_grad()  # merging is a deployment step; no autograd graph is needed
    def merge(self) -> nn.Linear:
        """Return a new Linear with merged weights (for deployment)."""
        merged_weight = self.base.weight + (self.lora_b @ self.lora_a) * self.scaling
        new_linear = nn.Linear(
            self.base.in_features,
            self.base.out_features,
            bias=self.base.bias is not None,
            device=self.base.weight.device,  # keep the merged layer on the same device/dtype
            dtype=self.base.weight.dtype,
        )
        new_linear.weight.copy_(merged_weight)
        if self.base.bias is not None:
            new_linear.bias.copy_(self.base.bias)
        return new_linear
```
2.4 Grouped-Query Attention (GQA)
```python
import math

import torch
import torch.nn as nn


class GroupedQueryAttention(nn.Module):
    def __init__(self, d_model: int, n_q_heads: int, n_kv_heads: int):
        super().__init__()
        assert d_model % n_q_heads == 0  # so that n_q_heads * d_k == d_model
        assert n_q_heads % n_kv_heads == 0
        self.n_q_heads = n_q_heads
        self.n_kv_heads = n_kv_heads
        self.n_rep = n_q_heads // n_kv_heads  # repeat factor
        self.d_k = d_model // n_q_heads
        self.wq = nn.Linear(d_model, n_q_heads * self.d_k, bias=False)
        self.wk = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.wv = nn.Linear(d_model, n_kv_heads * self.d_k, bias=False)
        self.wo = nn.Linear(d_model, d_model, bias=False)

    @staticmethod
    def _repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat KV heads to match Q heads: (B, n_kv, T, d_k) -> (B, n_q, T, d_k)."""
        if n_rep == 1:
            return x
        B, N, T, D = x.shape
        return x[:, :, None, :, :].expand(B, N, n_rep, T, D).reshape(B, N * n_rep, T, D)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        q = self.wq(x).view(B, T, self.n_q_heads, self.d_k).transpose(1, 2)
        k = self.wk(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        v = self.wv(x).view(B, T, self.n_kv_heads, self.d_k).transpose(1, 2)
        k = self._repeat_kv(k, self.n_rep)
        v = self._repeat_kv(v, self.n_rep)
        mask = torch.tril(torch.ones(T, T, device=x.device)).unsqueeze(0).unsqueeze(0)
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(mask == 0, float("-inf"))
        attn = torch.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, T, -1)
        return self.wo(out)
```
2.5 RoPE (Rotary Position Embedding)
```python
import torch


def precompute_rope_freqs(dim: int, max_len: int = 4096, base: float = 10000.0):
    """Precompute sin/cos tables for RoPE."""
    freqs = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))  # (dim/2,)
    t = torch.arange(max_len).float()  # (max_len,)
    freqs = torch.outer(t, freqs)  # (max_len, dim/2)
    return torch.cos(freqs), torch.sin(freqs)  # each (max_len, dim/2)


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Apply RoPE to input tensor (rotate_half pairing: element i with i + d_k/2).

    x: (batch, n_heads, seq_len, d_k)
    cos, sin: (seq_len, d_k/2); slice the precomputed tables to seq_len first
    """
    d_half = x.shape[-1] // 2
    x1 = x[..., :d_half]
    x2 = x[..., d_half:]
    cos = cos.unsqueeze(0).unsqueeze(0)  # broadcast over batch and heads
    sin = sin.unsqueeze(0).unsqueeze(0)
    out1 = x1 * cos - x2 * sin
    out2 = x2 * cos + x1 * sin
    return torch.cat([out1, out2], dim=-1)
```
2.6 DPO Loss
```python
import torch
import torch.nn.functional as F


def dpo_loss(
    policy_logps_w: torch.Tensor,  # log pi_theta(y_w | x)
    policy_logps_l: torch.Tensor,  # log pi_theta(y_l | x)
    ref_logps_w: torch.Tensor,  # log pi_ref(y_w | x)
    ref_logps_l: torch.Tensor,  # log pi_ref(y_l | x)
    beta: float = 0.1,
) -> torch.Tensor:
    """Direct Preference Optimization loss."""
    log_ratio_w = policy_logps_w - ref_logps_w
    log_ratio_l = policy_logps_l - ref_logps_l
    logits = beta * (log_ratio_w - log_ratio_l)
    return -F.logsigmoid(logits).mean()
```
2.7 KV Cache Wrapper (Minimal)
```python
import torch


class KVCache:
    """Minimal KV cache (batch=1) for autoregressive generation."""

    def __init__(self, max_len: int, n_heads: int, d_k: int, device: torch.device):
        self.max_len = max_len
        self.k = torch.zeros(1, n_heads, max_len, d_k, device=device)
        self.v = torch.zeros(1, n_heads, max_len, d_k, device=device)
        self.cur_len = 0

    def append(self, new_k: torch.Tensor, new_v: torch.Tensor):
        """Append new KV (shape (1, n_heads, seq_len, d_k)) from prefill or one decoding step."""
        seq_len = new_k.shape[2]
        # Fail loudly instead of silently writing past the preallocated buffer.
        assert self.cur_len + seq_len <= self.max_len, "KV cache overflow"
        self.k[:, :, self.cur_len:self.cur_len + seq_len] = new_k
        self.v[:, :, self.cur_len:self.cur_len + seq_len] = new_v
        self.cur_len += seq_len

    def get(self):
        """Return the current cached KV (trimmed to cur_len)."""
        return self.k[:, :, :self.cur_len], self.v[:, :, :self.cur_len]
```
2.8 Symmetric INT8 Quantize / Dequantize
```python
import torch


def symmetric_quantize_int8(weight: torch.Tensor):
    """Per-tensor symmetric INT8 quantization."""
    # Clamp the scale away from zero so an all-zero tensor does not divide by zero.
    scale = weight.abs().max().clamp(min=1e-12) / 127.0
    # symmetric: [-127, 127] matches the /127 scale
    w_q = torch.round(weight / scale).clamp(-127, 127).to(torch.int8)
    return w_q, scale


def symmetric_dequantize_int8(w_q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Dequantize INT8 back to float."""
    return w_q.float() * scale
```
2.9 Tensor-Parallel Linear (Column / Row)
```python
import torch
import torch.nn as nn

# Megatron tensor-parallel Linear hinges on a pair of conjugate comm operators f / g:
#   f: forward identity, backward all-reduce; g: forward all-reduce, backward identity.
# Below we simulate 2-way sharding in a single process (all-reduce -> sum over shards,
# all-gather -> cat) and verify TP equals an unsharded Linear exactly.


def column_parallel(X, W, b, n_shards=2):
    """Column-parallel: split W along output dim; each rank computes X·W_iᵀ locally;
    output is feature-sharded (no comm needed to obtain the partial output)."""
    Ws, bs = torch.chunk(W, n_shards, dim=0), torch.chunk(b, n_shards, dim=0)
    outs = [X @ Wi.T + bi for Wi, bi in zip(Ws, bs)]  # local matmul per rank
    # validation-only concat; in a fused MLP the output stays sharded (no gather)
    return torch.cat(outs, dim=-1)


def row_parallel(X, W, b, n_shards=2):
    """Row-parallel: input X is feature-sharded; split W along input dim;
    each rank computes a partial product, summed via all-reduce."""
    Xs, Ws = torch.chunk(X, n_shards, dim=-1), torch.chunk(W, n_shards, dim=1)
    partial = [Xi @ Wi.T for Xi, Wi in zip(Xs, Ws)]
    return sum(partial) + b  # g operator: all-reduce (sum); bias added once


# --- Verify: TP equals a plain Linear ---
torch.manual_seed(0)
B, d_in, d_out = 4, 8, 6
X = torch.randn(B, d_in)
ref = nn.Linear(d_in, d_out)
W, b = ref.weight.data, ref.bias.data  # W: (d_out, d_in), b: (d_out,)
Y_ref = ref(X)
print("column-parallel max err:", (column_parallel(X, W, b) - Y_ref).abs().max().item())  # ~0
print("row-parallel max err:", (row_parallel(X, W, b) - Y_ref).abs().max().item())  # ~0
```
A Megatron MLP chains column-parallel → GeLU (local) → row-parallel, so the whole block needs only one all-reduce in the forward pass (one in backward) — minimizing communication.
Part 3 — Interview Questions
L1 — Basic
Q1: What is the time complexity of self-attention in a Transformer? How can it be reduced?
Answer: Standard self-attention has time complexity $O(n^2 d)$ ($n$ = sequence length, $d$ = dimension) because it must compute an $n \times n$ attention matrix. Methods to reduce it include:
- FlashAttention: Does not change the mathematical result or the asymptotic FLOP count; reduces wall-clock time by tiling and recomputation to minimize HBM accesses
- Sparse Attention: Longformer/BigBird uses local windows + global tokens, reducing complexity to $O(n \cdot w)$ for window size $w$, i.e., linear in $n$
- Linear Attention: Approximates softmax with a kernel function, reducing complexity to $O(n d^2)$ (linear in $n$), but usually with some accuracy loss
Follow-up: Why is FlashAttention not considered an "approximate" attention? What low-level optimizations does it perform?
FlashAttention loads Q, K, V in tiles into SRAM, computes the softmax online normalization in SRAM (by maintaining a running max and running sum), then writes the result back to HBM. It is mathematically equivalent to standard attention — it merely reduces the number of HBM read/write operations.
Q2: What is Layer Normalization? How does it differ from Batch Normalization?
Answer:
- BatchNorm: Computes mean and variance across the batch dimension for each feature. Requires maintaining running mean/var during training and uses fixed statistics at inference. Sensitive to batch size; not suitable for variable-length sequences.
- LayerNorm: Computes mean and variance across the feature dimension for each sample (each token is normalized independently); does not depend on batch statistics. The standard choice in Transformers.
Follow-up: What advantage does RMSNorm have over LayerNorm?
RMSNorm removes the mean-centering step and only rescales by the root mean square: $\text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}} \odot \gamma$. Slightly less computation, similar practical performance; adopted by the LLaMA series.
Q3: What is gradient clipping? Why is it nearly universal in LLM training?
Answer: Gradient clipping constrains the global norm of the gradient to a threshold $c$: $g \leftarrow g \cdot \min\!\left(1, \frac{c}{\|g\|_2}\right)$. In LLM training, a small number of anomalous samples can produce extremely large gradients (gradient spikes), causing abrupt parameter changes or NaN values. Gradient clipping (typically $c = 1.0$) is the standard safeguard against training collapse.
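A minimal sketch of global-norm clipping on a flat list of gradient values. The function name is illustrative and the small `eps` in the denominator is an assumption for numerical safety; frameworks apply the same rule across all parameter tensors at once.

```python
import math

def clip_grad_norm(grads, max_norm=1.0, eps=1e-6):
    """Scale all gradients by min(1, max_norm / ||g||_2); returns (clipped, norm)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = min(1.0, max_norm / (total_norm + eps))
    return [g * coef for g in grads], total_norm

clipped, norm = clip_grad_norm([3.0, 4.0], max_norm=1.0)
print(norm, [round(g, 3) for g in clipped])  # 5.0 [0.6, 0.8]
```

Returning the pre-clip norm is useful for the monitoring discussed in the follow-up: logging it per step shows how often clipping triggers.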
Follow-up: How can you tell whether the gradient clipping threshold is set appropriately?
Monitor the frequency of clipping events in the training log. Occasional triggering (< 5% of steps) is normal; frequent triggering suggests the learning rate may be too large; never triggering during otherwise stable training suggests the threshold may be too loose.
Q4: What are warmup and cosine decay? Why are they commonly used in LLM pre-training?
Answer:
- Warmup: Linearly increase the learning rate at the beginning of training (typically for the first 1%–3% of steps), because with random initialization the gradient direction is unstable and a large LR may cause divergence.
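- Cosine decay: After warmup, the LR follows a cosine curve from peak to near zero: $\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\pi \cdot \frac{t - t_{\text{warmup}}}{T - t_{\text{warmup}}}\right)\right)$, where $T$ is the total number of steps.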
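The two phases combine into a single schedule function. This is a sketch with illustrative names and hyperparameters (10,000 steps, 1% warmup, peak LR 3e-4); the linear warmup here starts from `peak_lr / warmup_steps` rather than exactly zero, which is one common convention among several.

```python
import math

def lr_at(step, total_steps, warmup_steps, peak_lr, min_lr=0.0):
    """Linear warmup followed by cosine decay from peak_lr to min_lr."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

for s in (0, 99, 100, 5050, 10000):
    print(s, f"{lr_at(s, total_steps=10000, warmup_steps=100, peak_lr=3e-4):.2e}")
```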
Follow-up: How does a WSD (Warmup-Stable-Decay) schedule differ from a cosine schedule?
WSD maintains a constant LR after warmup (the stable phase), then decays rapidly at the end. Its advantage is that mid-training checkpoints have good quality, making it suitable for scenarios that need to evaluate downstream tasks from intermediate checkpoints.
Q5: Explain the basic principle of FlashAttention and why it is faster.
Answer: The core of FlashAttention is an IO-aware algorithm design:
- Split Q, K, V into small blocks, each small enough to fit in GPU SRAM (on-chip memory)
- Compute softmax and matrix multiplication within SRAM
- Use online softmax (maintaining row-wise running max and sum statistics) to avoid needing global information for the softmax
- Avoid writing the attention matrix to HBM (GPU memory), thereby reducing HBM reads and writes
Source of speedup: standard attention must write/read the $N \times N$ attention matrix from HBM, making HBM bandwidth the bottleneck. FlashAttention concentrates computation in SRAM, reducing HBM access from $O(Nd + N^2)$ to $O(N^2 d^2 / M)$ ($N$ = sequence length, $d$ = head dimension, $M$ = SRAM size).
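The online-softmax trick is the part that lets blocks be processed independently. Below is a scalar sketch for one query row: it processes scores block by block while keeping a running max, running sum, and running weighted accumulator, and matches the full softmax result. The function name and block size are illustrative; the real kernel does this per tile on matrices in SRAM.

```python
import math

def online_softmax_weighted_sum(scores, values, block_size=2):
    """Compute sum_i softmax(scores)_i * values_i one block at a time."""
    running_max, running_sum, acc = float("-inf"), 0.0, 0.0
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        new_max = max(running_max, max(s_blk))
        rescale = math.exp(running_max - new_max)   # re-base earlier blocks
        weights = [math.exp(s - new_max) for s in s_blk]
        running_sum = running_sum * rescale + sum(weights)
        acc = acc * rescale + sum(w * v for w, v in zip(weights, v_blk))
        running_max = new_max
    return acc / running_sum

scores = [1.0, 3.0, -2.0, 5.0, 0.5]
values = [10.0, 20.0, 30.0, 40.0, 50.0]

# Reference: ordinary softmax over the full row.
m = max(scores)
exps = [math.exp(s - m) for s in scores]
reference = sum(e * v for e, v in zip(exps, values)) / sum(exps)

print(abs(online_softmax_weighted_sum(scores, values) - reference) < 1e-12)  # True
```

Only three running scalars per row need to persist between blocks, which is why the full score row never has to be materialized.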
Follow-up: How large is the benefit of FlashAttention for training vs inference respectively?
In training, the main saving is in HBM accesses during backpropagation (the attention matrix is not stored in the forward pass; it is recomputed as needed during backward). In inference, the benefit is primarily in the prefill stage (long prompts); the benefit in the decode stage (single token per step) is smaller.
Q6: What is PEFT (Parameter-Efficient Fine-Tuning)? Name at least three methods and briefly describe each.
Answer:
- LoRA / QLoRA: Insert a low-rank bypass ($\Delta W = BA$) alongside a weight matrix; only the bypass parameters are trained. QLoRA further quantizes the base weights to 4-bit.
- Prefix Tuning: Prepend trainable "virtual token" vectors to the keys and values of each attention layer.
- Adapter: Insert a small MLP bottleneck (down-projection → nonlinearity → up-projection) between Transformer sublayers; only adapter parameters are trained.
- Prompt Tuning: Prepend a small number of trainable soft prompt vectors to the input embeddings (only at the input layer).
Follow-up: What is the trade-off between parameter efficiency and expressiveness for these methods?
Fewer parameters save more memory but cap the expressiveness. LoRA, acting directly on weight matrices, typically outperforms adapters and prefix tuning at similar parameter counts. In extreme scenarios (e.g., only tens of examples), fewer parameters can actually prevent overfitting.
Q7: What is the difference between continuous batching and static batching?
Answer:
- Static batching: Collects a batch of requests and waits until all requests finish generation before releasing the batch. If one request is much shorter than others, GPU resources are wasted once it completes (padding and idle waiting).
- Continuous batching (iteration-level scheduling): After each generation step (one token), checks whether any request has finished; completed requests are immediately replaced by new ones. GPU utilization is significantly improved.
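The difference can be made concrete with a toy step counter. This is a simplified sketch under stated assumptions: every request emits exactly one token per decode step, prefill cost is ignored, and the function names are illustrative only.

```python
from collections import deque

def static_batching_steps(lengths, batch_size):
    """Each batch runs until its longest request finishes."""
    steps = 0
    for i in range(0, len(lengths), batch_size):
        steps += max(lengths[i:i + batch_size])
    return steps

def continuous_batching_steps(lengths, batch_size):
    """After every decode step, finished requests are swapped for queued ones."""
    queue = deque(lengths)
    active = []  # remaining tokens for each in-flight request
    steps = 0
    while queue or active:
        # Fill free slots before the next iteration.
        while queue and len(active) < batch_size:
            active.append(queue.popleft())
        steps += 1
        # Every in-flight request emits one token; finished ones leave the batch.
        active = [r - 1 for r in active if r - 1 > 0]
    return steps

lengths = [2, 10, 2, 10]  # output lengths in tokens (made-up workload)
print(static_batching_steps(lengths, 2), continuous_batching_steps(lengths, 2))
```

With this made-up workload the short requests free their slots early, so the continuous scheduler needs fewer total decode steps than the static one.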
Follow-up: Are PagedAttention and continuous batching used together?
Yes. Continuous batching answers "when to schedule requests"; PagedAttention answers "how to allocate KV cache memory" — it splits the KV cache into fixed-size pages, allocates on demand, and avoids memory fragmentation caused by variable request lengths.
L2 — Intermediate
Q8: Explain what each of the three ZeRO stages does, and their respective communication overhead.
Answer:
- ZeRO-1: Shards optimizer states (AdamW's FP32 master copy + $m$ + $v$, 12 bytes/param) across $N$ GPUs. Each GPU holds only $1/N$ of the optimizer state; after the update, parameters are gathered via AllGather.
- ZeRO-2: On top of ZeRO-1, also shards gradients. Each GPU keeps only $1/N$ of the gradients (the rest are discarded after Reduce-Scatter).
- ZeRO-3: Also shards parameters. During forward and backward passes, parameters are All-Gathered on demand and released after use.
Communication: ZeRO-1/2 match standard DP ($2\Psi$ per step, $\Psi$ = parameter count). ZeRO-3 needs an AllGather of params in both forward and backward, plus a Reduce-Scatter of grads in backward, for a total of $3\Psi$ (~1.5× plain DP's $2\Psi$).
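A back-of-the-envelope calculator makes the per-stage memory savings easy to check. This is a sketch under stated assumptions: BF16 weights and gradients (2 bytes each), 12 bytes/param of optimizer state, and no activations or framework buffers. The function name is hypothetical.

```python
def zero_memory_per_gpu_gb(num_params, num_gpus, stage):
    """Per-GPU model-state memory in GB (1e9 bytes) under mixed-precision AdamW.

    Assumes 2 bytes/param weights, 2 bytes/param gradients, and 12 bytes/param
    optimizer state (FP32 master copy + m + v). Activations are ignored.
    """
    params, grads, optim = 2.0, 2.0, 12.0
    if stage >= 1:
        optim /= num_gpus   # ZeRO-1: shard optimizer states
    if stage >= 2:
        grads /= num_gpus   # ZeRO-2: also shard gradients
    if stage >= 3:
        params /= num_gpus  # ZeRO-3: also shard parameters
    return num_params * (params + grads + optim) / 1e9

for stage in range(4):
    print(stage, zero_memory_per_gpu_gb(7e9, 8, stage))
```

For a 7B model on 8 GPUs, the estimate falls from 112 GB with no sharding to 14 GB at stage 3, which is why the larger stages matter mainly when model states, not activations, dominate memory.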
Follow-up: When is ZeRO-2 better than ZeRO-3?
When model parameters fit on a single GPU but optimizer states do not, ZeRO-2 has lower communication overhead. A typical scenario is fine-tuning a medium-scale model (e.g., 7B–13B) with gradient checkpointing + ZeRO-2.
Q9: What is the concrete RLHF-PPO training loop? Why is KL penalization needed?
Answer: Each RLHF-PPO step:
- Sample a batch of prompts; generate responses with the current policy
- Score each (prompt, response) pair with the reward model
- Compute the KL penalty using the reference policy
- Compute advantages (typically with GAE)
- Update the policy with the PPO clip objective (multiple mini-batch update rounds)
Why KL penalization is needed: Without a KL constraint, the policy quickly drifts into out-of-distribution (OOD) blind spots of the reward model, generating responses that score high under the RM but that humans actually dislike (reward hacking). The KL penalty keeps the policy close to $\pi_{\text{ref}}$ (i.e., the SFT model).
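One common way to apply the penalty is to fold it into per-token rewards before computing advantages. The sketch below assumes the simple per-token estimate `log pi - log pi_ref` and places the reward-model score on the last token; the function name and the default `beta` are illustrative assumptions, and real frameworks differ in the exact estimator.

```python
def kl_shaped_rewards(rm_score, logp_policy, logp_ref, beta=0.1):
    """Per-token rewards: -beta * (log pi - log pi_ref), plus the RM score on the last token."""
    rewards = [-beta * (lp - lr) for lp, lr in zip(logp_policy, logp_ref)]
    rewards[-1] += rm_score  # the scalar RM score is credited at the end of the response
    return rewards

# Tokens where the policy is more confident than the reference are penalized.
print(kl_shaped_rewards(1.0, [-1.0, -2.0], [-1.5, -2.0]))
```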
Follow-up: Can you give a concrete example of reward hacking?
For example, if the reward model favors long answers (because good answers in the training data tended to be longer), the policy may learn to produce very long, repetitive responses to any question to obtain a high score — even though a human evaluator would find them verbose and unhelpful.
Q10: How do you prevent catastrophic forgetting during instruction tuning?
Answer: Common approaches:
- Replay / mixed training: Mix some general instruction data or pretraining data into the SFT data
- LoRA / PEFT: Only a small number of parameters are updated; pretrained knowledge is preserved in the frozen base weights
- Regularization: Methods like EWC (Elastic Weight Consolidation) penalize large deviations of important parameters
- Low learning rate: When doing full fine-tuning, use an LR 1–2 orders of magnitude lower than pre-training
Follow-up: How do you quantify the degree of catastrophic forgetting?
Evaluate on both a general benchmark (e.g., MMLU, HellaSwag) and the target task benchmark before and after fine-tuning. If performance on the general benchmark drops by more than a few percentage points, significant forgetting has occurred.
Q11: What is the core difference between GPTQ and AWQ?
Answer:
- GPTQ (Optimal Brain Quantization series): Quantizes layer by layer, using second-order information (inverse Hessian) to minimize the reconstruction error of the layer output. Quantizes column by column; after quantizing each column, compensates the remaining columns.
- AWQ (Activation-Aware Weight Quantization): The core observation is that a small number of "salient channels" (channels with large activations) are critical to output quality. AWQ protects the weights of those channels (e.g., using per-channel scaling to effectively increase precision), rather than quantizing all weights uniformly.
Follow-up: When quantizing activations as well as weights (e.g., W8A8), why is SmoothQuant-style smoothing of activations so important?
Activations often contain outliers (abnormally large values), which stretch the quantization range and reduce effective precision. SmoothQuant migrates activation outliers to weights via a mathematically equivalent per-channel scaling, making the activation distribution more uniform so that both weights and activations can be quantized to lower bit-widths.
Q12: How do Sequence Parallelism and Tensor Parallelism work together?
Answer: In Megatron-LM's design:
- TP splits linear layers (weight matrices of attention and MLP)
- SP (Sequence Parallelism) splits activations of non-linear operations (LayerNorm, Dropout) — along the sequence dimension
- Junction: a TP layer ends with AllReduce (or ReduceScatter); an SP layer also requires communication. Megatron-LM fuses these two communication operations, so no additional communication is introduced.
Benefit: after the TP AllReduce, each GPU holds the full sequence in its activations (redundant). SP eliminates this redundancy; each GPU holds only $1/t$ of the sequence activations ($t$ = TP degree), significantly reducing activation memory.
Follow-up: Does SP help with gradient checkpointing?
Yes. SP reduces the activation volume stored on each GPU. Without gradient checkpointing, the activation memory of the LayerNorm/Dropout regions drops from $O(s)$ to $O(s/t)$ per GPU ($s$ = sequence length, $t$ = TP degree). Even with gradient checkpointing, the temporary memory used during recomputation is reduced proportionally.
Q13: Explain how a reward model is trained in RLHF and how reward model quality is evaluated.
Answer:
- Training: Uses the Bradley-Terry preference model. Given prompt $x$ and a response pair $(y_w, y_l)$ ($y_w$ labeled as better), the reward model loss is $\mathcal{L}_{\text{RM}} = -\log \sigma\big(r_\theta(x, y_w) - r_\theta(x, y_l)\big)$. The model is typically initialized from the SFT model, with the language model head replaced by a scalar output head.
- Evaluation metrics:
- Preference prediction accuracy: Accuracy at predicting which of a held-out preference pair is better
- Reward distribution separability: Whether the reward distributions of chosen and rejected responses are sufficiently separated
- Reward hack robustness: Whether the reward still ranks OOD responses (generated by the policy) reasonably
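The pairwise loss and the held-out accuracy metric can be sketched on scalar rewards. This is a minimal illustration operating on plain floats; in practice the rewards come from the scalar head of the model, and the function names here are hypothetical.

```python
import math

def bt_loss(r_chosen, r_rejected):
    """-log sigmoid(r_chosen - r_rejected), written in a numerically stable form."""
    margin = r_chosen - r_rejected
    # -log(sigmoid(m)) = log(1 + exp(-m)); branch to avoid overflow in exp.
    if margin >= 0:
        return math.log1p(math.exp(-margin))
    return -margin + math.log1p(math.exp(margin))

def preference_accuracy(pairs):
    """Fraction of (r_chosen, r_rejected) pairs that the reward model ranks correctly."""
    return sum(rc > rr for rc, rr in pairs) / len(pairs)

print(bt_loss(2.0, 0.0), bt_loss(0.0, 2.0))
```

The loss is small when the chosen response scores well above the rejected one and grows roughly linearly when the ranking is inverted.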
Follow-up: Why does the reward model need to be updated periodically?
Because the policy changes continuously during RL training, the distribution of generated responses gradually shifts away from the RM's training distribution (the SFT model's output distribution). On out-of-distribution data, the RM may produce inaccurate scores, leading to reward hacking.
Q14: What problem does vLLM's PagedAttention solve? What is the mechanism?
Answer:
- Problem: Traditional KV cache pre-allocates a contiguous memory block for each request (at maximum sequence length). But actual generation lengths vary, so reserved slots go unused inside each allocation (internal fragmentation), free memory ends up scattered in gaps too small for a new contiguous allocation (external fragmentation), and per-request contiguous buffers prevent sharing KV across requests.
- PagedAttention mechanism: Borrows the virtual memory paging idea from operating systems:
- KV cache is divided into fixed-size blocks (e.g., each block stores KV for 16 tokens)
- A block table records the mapping from logical blocks to physical blocks for each request
- New blocks are dynamically allocated when new tokens are generated; blocks are freed when a request finishes
- Supports copy-on-write: for beam search candidates sharing the same prefix, KV blocks can be shared
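The block-table bookkeeping can be sketched as a toy allocator. This is an illustrative sketch only and not vLLM's actual block manager: the class and method names are hypothetical, and copy-on-write sharing is omitted.

```python
class BlockAllocator:
    """Toy KV-cache block manager: logical blocks map to physical block ids."""

    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free = list(range(num_blocks))  # pool of physical block ids
        self.tables = {}   # request id -> list of physical block ids (the block table)
        self.lengths = {}  # request id -> number of tokens stored

    def append_token(self, req_id):
        table = self.tables.setdefault(req_id, [])
        n = self.lengths.get(req_id, 0)
        if n % self.block_size == 0:  # last block is full, or no block yet
            if not self.free:
                raise MemoryError("no free KV blocks")
            table.append(self.free.pop())
        self.lengths[req_id] = n + 1

    def release(self, req_id):
        # Return all of the request's blocks to the pool.
        self.free.extend(self.tables.pop(req_id, []))
        self.lengths.pop(req_id, None)
```

Because blocks are allocated one at a time as tokens arrive, the waste per request is bounded by one partially filled block rather than by the gap to the maximum sequence length.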
Follow-up: Does PagedAttention negatively affect latency?
The indirect addressing via the block table introduces a small overhead (compared to direct access with contiguous memory), but in practice this overhead is very small (typically < 5%), since attention computation itself is compute-bound or memory-bound and the addressing overhead is not the bottleneck.
Q15: How would you design an offline evaluation harness for an LLM? What aspects need to be considered?
Answer:
- Task abstraction: Each task defines a dataset, prompt template (few-shot format), metric, and output type (generation / loglikelihood)
- Evaluation modes:
- Likelihood-based (e.g., MMLU): compute log-prob for each option, select the maximum
- Generation-based (e.g., GSM8K): generate output then evaluate with rules / code execution
- LLM-as-judge (e.g., MT-Bench): score with a stronger model
- Reproducibility: Fix seeds; record prompt templates and few-shot examples; temperature=0 (or fixed)
- Efficiency: Likelihood tasks suit large batches; generation tasks should be sorted by length to reduce padding
- Contamination detection: Check for n-gram overlap between training data and test sets
Follow-up: Why distinguish between "knowledge" and "reasoning" evaluation?
Because a model may perform well on knowledge-heavy tasks (e.g., factual questions in MMLU) but poorly on reasoning-heavy tasks (e.g., math, code), or vice versa. Evaluating them separately helps pinpoint where the model's capabilities are weak.
Q16: How do you choose an appropriate LoRA rank for LLM fine-tuning?
Answer: Factors to consider:
- Task complexity: Simple classification/extraction tasks usually need only r=4–16; complex reasoning/generation tasks may require r=32–64
- Data size: Use a small rank with little data to prevent overfitting; increase rank when data is plentiful to raise capacity
- Target modules: Applying LoRA only to q_proj and v_proj (fewest parameters) vs. all linear layers (q/k/v/o + MLP gate/up/down) trades parameter count against effectiveness — all linear layers typically yield better results
- Common practice: Start from r=16, α=2r; compare r=8/16/32/64 on the validation set
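To see how rank and target modules drive the trainable parameter count, a small calculator helps. The layer shapes below are assumptions for a 7B-class decoder (hidden size 4096, MLP size 11008, 32 layers, full multi-head K/V); the function name is hypothetical.

```python
def lora_param_count(shapes, rank):
    """Total LoRA parameters: each (d_in, d_out) matrix adds rank * (d_in + d_out)."""
    return sum(rank * (d_in + d_out) for d_in, d_out in shapes)

# Assumed shapes for one layer of a 7B-class model (illustrative only).
hidden, mlp, layers = 4096, 11008, 32
attn_qv = [(hidden, hidden)] * 2
all_linear = [(hidden, hidden)] * 4 + [(hidden, mlp), (hidden, mlp), (mlp, hidden)]

for rank in (8, 16, 32, 64):
    print(rank,
          layers * lora_param_count(attn_qv, rank),
          layers * lora_param_count(all_linear, rank))
```

The count scales linearly with rank, and under these assumed shapes targeting all linear layers costs roughly five times as many parameters as targeting q_proj and v_proj alone.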
Follow-up: Can LoRA be combined with QLoRA? Is the accuracy loss large when using 4-bit quantized base weights + LoRA?
Yes — QLoRA is exactly this approach. In practice, 4-bit NF4 quantized base weights + LoRA fine-tuning matches FP16 full fine-tuning within an acceptable margin on most tasks (typically within 1–3 percentage points) while saving enormous memory.
Q-RLHF-A (L2): Why is GPU utilization low in naive co-located PPO? How does a disaggregated architecture solve it?
Answer:
Naive co-located PPO runs rollout and training serially on the same set of GPUs:
- Rollout phase: the actor performs autoregressive inference (memory-bound; throughput limited by HBM bandwidth); the trainer waits idle.
- Train phase: PPO backpropagation is compute-intensive; the rollout worker waits idle.
The two phases alternate, and overall GPU utilization is the weighted average of each phase's utilization — well below peak training utilization.
How the disaggregated architecture solves it:
- Independent rollout workers (vLLM/SGLang engines) continuously generate responses, producing a rollout buffer.
- Independent train workers (ZeRO-3/FSDP) pull data from the buffer and continuously execute PPO/GRPO updates.
- The two sets of workers run concurrently; weights are synchronized at a fixed frequency (typically each iteration).
This way, rollout and training are each optimized for their own workload (inference engine vs. training framework) without blocking each other.
Follow-up 1: How much weight-sync bandwidth is needed between rollout workers and train workers in a disaggregated architecture?
For a 7B BF16 model, one complete weight sync transfers ~14 GB. If syncing once per minute, that is ~14 GB ÷ 60 s ≈ 0.23 GB/s, well below the bandwidth ceiling of NVLink/RDMA (sync overhead is negligible). With LoRA-in-RL, only LoRA parameters need syncing (~100 MB scale), greatly reducing sync overhead.
Follow-up 2: What effect does staleness from async rollout have on PPO? How can it be mitigated?
Staleness causes rollout to generate data with old parameters $\theta_{\text{old}}$, introducing an off-policy bias. PPO's importance ratio clip ($r_t \in [1-\epsilon, 1+\epsilon]$) tolerates mild staleness, but when staleness is large, gradient estimate variance grows and training becomes unstable. Mitigation: bound staleness by syncing weights at least every few mini-batch updates, or use more aggressive importance sampling correction.
L3 — Deep
Q17: How does Megatron-LM's Column-Parallel and Row-Parallel Linear reduce the number of AllReduce operations?
Answer:
Consider two consecutive linear transforms (an MLP block), $Y = \text{GeLU}(XA)$, $Z = YB$:
- Column-Parallel $A$: Split column-wise as $A = [A_1, A_2]$; each GPU computes $Y_i = \text{GeLU}(XA_i)$ independently, no communication needed. GeLU is element-wise and naturally separable.
- Row-Parallel $B$: Split row-wise as $B = [B_1; B_2]$; each GPU computes $Z_i = Y_i B_i$.
- Final AllReduce: $Z = Z_1 + Z_2$ (one AllReduce).
Key insight: Column-Parallel output is exactly the input of Row-Parallel; the intermediate nonlinearity (GeLU) is element-wise and requires no communication. Therefore the entire MLP block needs only one AllReduce (forward), and one in the backward pass as well.
Without this design, each layer would require a separate AllReduce, doubling the communication volume.
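The equivalence between the sharded and unsharded MLP can be checked numerically with two simulated ranks. This is a pure-Python sketch for illustration: the helper names are hypothetical, the AllReduce is modeled as a plain sum, and GeLU uses the tanh approximation.

```python
import math

def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def gelu(m):
    # tanh approximation of GeLU, applied element-wise
    c = math.sqrt(2.0 / math.pi)
    return [[0.5 * v * (1.0 + math.tanh(c * (v + 0.044715 * v ** 3))) for v in row]
            for row in m]

def add(a, b):
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]

def mlp_reference(x, A, B):
    return matmul(gelu(matmul(x, A)), B)

def mlp_tensor_parallel(x, A, B):
    """Two simulated ranks: A split by columns, B split by rows."""
    half = len(A[0]) // 2
    A_shards = [[row[:half] for row in A], [row[half:] for row in A]]
    B_shards = [B[:half], B[half:]]
    # Each rank computes GeLU(x A_i) B_i with no communication.
    partials = [matmul(gelu(matmul(x, A_i)), B_i)
                for A_i, B_i in zip(A_shards, B_shards)]
    # The single forward AllReduce is modeled as a sum of the partial outputs.
    return add(partials[0], partials[1])
```

The only cross-rank operation is the final sum, which corresponds to the single forward AllReduce of the block.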
Follow-up: Can the same trick be applied to the QKV projection and output projection in the attention block?
Yes. QKV projection uses Column-Parallel (outputs are distributed to each head, which naturally splits column-wise); the output projection uses Row-Parallel, followed by AllReduce. The entire attention block also needs only one AllReduce.
Q18: Why is speculative decoding lossless? Derive the acceptance probability.
Answer:
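Let target model distribution be $p(x)$ and draft model distribution be $q(x)$.
Accept-reject sampling:
- Sample token $x$ from $q$
- If $p(x) \ge q(x)$: accept (probability 1)
- If $p(x) < q(x)$: accept with probability $p(x)/q(x)$
Total probability of emitting token $x$:
- Sampled from $q$ and accepted: $q(x) \cdot \min\left(1, \frac{p(x)}{q(x)}\right) = \min(p(x), q(x))$
- Sampled from $q$, rejected, then resampled to $x$: on rejection, the replacement token is drawn from the normalized residual $\frac{\max(0,\ p(x) - q(x))}{\sum_{x'} \max(0,\ p(x') - q(x'))}$
Final effective probability:
$$P(x) = \min(p(x), q(x)) + \Big(1 - \sum_{x'} \min(p(x'), q(x'))\Big) \cdot \frac{\max(0,\ p(x) - q(x))}{\sum_{x'} \max(0,\ p(x') - q(x'))}$$
The second term = reject probability × normalized residual. Because $1 - \sum_{x'} \min(p(x'), q(x')) = \sum_{x'} \max(0,\ p(x') - q(x'))$, the reject probability cancels the denominator to give $\max(0,\ p(x) - q(x))$, so $P(x) = \min(p(x), q(x)) + \max(0,\ p(x) - q(x)) = p(x)$, exactly the target distribution.
Core intuition: when $q(x) < p(x)$, the draft model "under-sampled" and the deficit is compensated from the residual probability mass after rejection; when $q(x) > p(x)$, rejection removes the excess probability.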
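The losslessness claim can be verified exactly on a small vocabulary by computing the output distribution of one accept/reject step in closed form. This is a sketch for a single token position with distributions given as plain lists; the function name is hypothetical.

```python
def speculative_output_distribution(p, q):
    """Exact output distribution of one accept/reject step for target p and draft q."""
    # Probability of drawing x from q and accepting it: q(x) * min(1, p(x)/q(x)).
    accepted = [min(pi, qi) for pi, qi in zip(p, q)]
    reject_prob = 1.0 - sum(accepted)
    # On rejection, resample from the normalized residual max(0, p - q).
    residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    norm = sum(residual)
    if norm == 0.0:  # p == q everywhere: nothing is ever rejected
        return accepted
    return [a + reject_prob * r / norm for a, r in zip(accepted, residual)]

p = [0.5, 0.3, 0.2]  # target distribution (made-up numbers)
q = [0.2, 0.5, 0.3]  # draft distribution (made-up numbers)
print(speculative_output_distribution(p, q))
```

Up to floating-point rounding, the returned distribution equals `p` regardless of how poor the draft `q` is; a worse draft only raises the rejection probability.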
Follow-up: Where is the efficiency bottleneck of speculative decoding?
The bottleneck is the draft model's acceptance rate. If the distributions of the draft model and target model diverge significantly, the acceptance rate is low, most draft tokens are rejected, and the speedup is poor. Improvements include: Medusa-style multi-head prediction, or selecting a draft model whose distribution is closer to the target model.
Q19: How is DPO derived from the Bradley-Terry preference model?
Answer:
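Step 1: The Bradley-Terry model assumes preferences are generated by a latent reward $r^*$:
$$p^*(y_w \succ y_l \mid x) = \sigma\big(r^*(x, y_w) - r^*(x, y_l)\big)$$
Step 2: Under a KL constraint, the closed-form solution for the optimal policy is:
$$\pi^*(y \mid x) = \frac{1}{Z(x)}\, \pi_{\text{ref}}(y \mid x) \exp\left(\frac{1}{\beta} r(x, y)\right)$$
where $Z(x) = \sum_y \pi_{\text{ref}}(y \mid x) \exp\left(\frac{1}{\beta} r(x, y)\right)$ is the partition function.
Step 3: Solve for the reward:
$$r(x, y) = \beta \log \frac{\pi^*(y \mid x)}{\pi_{\text{ref}}(y \mid x)} + \beta \log Z(x)$$
Step 4: Substitute into the Bradley-Terry model; $\beta \log Z(x)$ cancels in the difference:
$$p(y_w \succ y_l \mid x) = \sigma\left(\beta \log \frac{\pi^*(y_w \mid x)}{\pi_{\text{ref}}(y_w \mid x)} - \beta \log \frac{\pi^*(y_l \mid x)}{\pi_{\text{ref}}(y_l \mid x)}\right)$$
Step 5: Replace $\pi^*$ with the trainable $\pi_\theta$ and take the negative log-likelihood to obtain the DPO loss.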
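The resulting loss for a single preference pair can be sketched directly from sequence log-probabilities. This is a minimal illustration on plain floats: the function name and the default `beta` are assumptions, and the inputs are the summed log-probabilities of the chosen (`w`) and rejected (`l`) responses under the policy and the reference model.

```python
import math

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """-log sigmoid(beta * [(log-ratio of chosen) - (log-ratio of rejected)])."""
    margin = beta * ((logp_w - ref_logp_w) - (logp_l - ref_logp_l))
    # -log(sigmoid(m)) = softplus(-m), computed in a numerically stable way.
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))

# When the policy equals the reference, the implicit reward margin is zero.
print(dpo_loss(-10.0, -12.0, -10.0, -12.0))
```

Raising the policy's log-probability of the chosen response relative to the reference, or lowering it for the rejected response, decreases the loss.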
Follow-up: DPO's derivation assumes preference data comes from the optimal policy; what practical problems does this assumption cause?
In practice, preference data usually comes from the SFT model (not the optimal policy), which means the reward implicitly learned by DPO may be inaccurate. This is also why online DPO (iterative DPO, where each round generates data with the latest policy) typically outperforms offline DPO.
Q20: What is benchmark saturation in LLM evaluation, and how do you address it?
Answer:
- Problem: When mainstream models score near the ceiling on a benchmark (e.g., >90% on MMLU), discriminability drops. Possible causes include:
- Training data contamination (test set data leaked into the training set)
- Insufficient task difficulty (primarily knowledge retrieval, not deep reasoning)
- Format optimization (models tuned to the benchmark's prompt format)
- Approaches:
- Use harder benchmarks (e.g., MMLU-Pro, GPQA, MATH)
- Use dynamically generated evaluation questions
- Rely on human evaluation (e.g., Chatbot Arena Elo rankings)
- Detect and report data contamination
Follow-up: What are the design philosophies of HELM and lm-evaluation-harness?
HELM (Stanford) emphasizes "comprehensiveness" — covering multiple dimensions (accuracy, calibration, robustness, fairness, efficiency) with detailed documentation and standardized evaluation procedures for each scenario, but extending to new tasks is relatively heavy. lm-evaluation-harness (EleutherAI) emphasizes "flexibility and community contribution" — tasks are defined concisely (config-driven); the community can quickly add new tasks; 400+ tasks provide broad coverage, though standardization is relatively lower.
Q21: Explain the motivation and design of disaggregated serving (prefill/decode separation).
Answer:
Motivation: Prefill (processing the prompt) and decode (generating tokens one at a time) have completely different computational characteristics:
| Characteristic | Prefill | Decode |
|---|---|---|
| Computation type | Compute-bound (large matrix multiplications) | Memory-bound (small batch, heavy KV cache access) |
| GPU utilization | High (compute-intensive) | Low (memory bandwidth bottleneck) |
| Optimal configuration | High-compute GPU | High-bandwidth memory GPU |
Disaggregated Serving Design:
- Prefill nodes: high-compute configuration; process prompts in large batches → generate KV cache
- Decode nodes: high-bandwidth configuration; receive KV cache → generate tokens one at a time
- KV cache is transferred between nodes over a high-speed network (RDMA/NCCL)
Benefits: The two stages can be scaled independently, preventing the memory-bound nature of the decode stage from dragging down the compute utilization of the prefill stage.
Follow-up: How large is the bandwidth requirement for KV cache transfer?
For a 70B model with sequence length 4K and FP16 KV cache, the KV cache per request is on the order of 1 GB with GQA (e.g., 80 layers × 8 KV heads × head dimension 128 gives ~1.3 GB) and roughly 8× larger with full multi-head KV (64 KV heads). If decode nodes need to ingest KV caches from tens of requests per second, tens of GB/s of network bandwidth is required, which is feasible on modern data-center RDMA networks.
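The per-request KV cache size follows from a simple product, so it is worth computing rather than memorizing. The configuration below (80 layers, head dimension 128, 8 KV heads for GQA or 64 for full multi-head attention) is an assumed 70B-class layout, and the function name is hypothetical; other architectures will give different numbers.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_elem=2):
    """Bytes of KV cache for one request: K and V, per layer, per KV head, per token."""
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_elem

gqa = kv_cache_bytes(num_layers=80, num_kv_heads=8, head_dim=128, seq_len=4096)
mha = kv_cache_bytes(num_layers=80, num_kv_heads=64, head_dim=128, seq_len=4096)
print(gqa / 1e9, "GB with GQA;", mha / 1e9, "GB with full MHA")
```

Multiplying the per-request size by the number of requests handed off per second gives the sustained transfer bandwidth a decode node must absorb.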
Q22: How do you manage the memory–compute trade-off of gradient checkpointing in distributed training?
Answer:
- Principle: During the forward pass, intermediate activations are not saved; only some "checkpoints" are kept (typically one per layer boundary). During backpropagation, activations are recomputed from the nearest checkpoint.
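- Memory: Reduced from $O(L \cdot A)$ ($L$ = number of layers, $A$ = activation size per layer) to $O(\sqrt{L} \cdot A)$ or $O(k \cdot A)$ ($k$ = number of checkpoints)
- Compute: Approximately 33% extra total compute (each checkpoint segment is recomputed forward once; one extra forward pass on top of forward + backward ≈ 3× forward cost)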
Practical choice:
- Do not use if memory is sufficient (saves time)
- Enable when memory is insufficient but a 33% training slowdown is acceptable
- Can be selectively enabled (e.g., only for certain large layers)
Follow-up: For selective gradient checkpointing, how do you choose which layers to checkpoint?
Typically choose layers with the largest activations (e.g., the attention matrix is an $O(N^2)$ memory consumer). Layers with small activations (e.g., LayerNorm, embedding) are not checkpointed, achieving a better balance between memory savings and computation overhead.
Q23: Explain PPO's clipping mechanism and why it may need adjustment in RLHF.
Answer:
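PPO's clipped surrogate objective:
$$L^{\text{CLIP}}(\theta) = \mathbb{E}_t\left[\min\left(r_t(\theta)\,\hat{A}_t,\ \text{clip}\big(r_t(\theta),\ 1-\epsilon,\ 1+\epsilon\big)\,\hat{A}_t\right)\right]$$
where $r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}$; $\epsilon$ is typically 0.1–0.2.
Purpose: When $r_t(\theta)$ deviates too far from 1, the clip limits the change in the objective function, preventing excessively large single-step updates.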
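The clipping behavior is easiest to see on a few hand-picked ratios. The sketch below computes the mean clipped surrogate from log-probabilities and advantages given as plain lists; the function name is hypothetical.

```python
import math

def ppo_clip_objective(logp_new, logp_old, advantages, eps=0.2):
    """Mean clipped surrogate objective (to be maximized)."""
    total = 0.0
    for ln, lo, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(ln - lo)                        # importance ratio
        clipped = min(max(ratio, 1.0 - eps), 1.0 + eps)  # clip to [1-eps, 1+eps]
        total += min(ratio * adv, clipped * adv)         # pessimistic choice
    return total / len(advantages)

# Ratio 1.5 with positive advantage: the gain is capped at (1 + eps) * A.
print(ppo_clip_objective([math.log(1.5)], [0.0], [1.0]))
# Ratio 1.5 with negative advantage: no cap, the full penalty applies.
print(ppo_clip_objective([math.log(1.5)], [0.0], [-1.0]))
```

The `min` makes the objective a lower bound: improvements beyond the clip range earn no extra credit, while moves that make things worse are still fully penalized.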
Special considerations in RLHF:
- In standard RL (games, etc.), the state-action space is large and $r_t$ does not deviate much
- In RLHF, the language model's generation space is exponential and the policy may change rapidly
- Therefore $\epsilon$ may need to be reduced, or the number of PPO update epochs increased to better exploit each batch of samples
Follow-up: How is the value function loss balanced against the policy loss in PPO?
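Typically a weighted sum: $\mathcal{L} = -L^{\text{CLIP}} + c_1 L^{\text{VF}} - c_2 S[\pi_\theta]$, where $L^{\text{CLIP}}$ is the clipped surrogate objective, $L^{\text{VF}}$ is the MSE loss of the value function, and $S[\pi_\theta]$ is an entropy bonus to prevent premature collapse. In RLHF, tuning $c_1$ and $c_2$ is critical for training stability.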
Q24: How would you design a system to detect benchmark data contamination?
Answer:
- N-gram overlap detection: Compute the intersection of n-grams (e.g., 8-gram, 13-gram) from the test set against the training data. If the overlap rate exceeds a threshold, flag as potentially contaminated.
- Membership inference: Check whether the model's perplexity on test set samples is anomalously low compared to held-out data; low perplexity may indicate the sample appeared in training.
- Canonical order test: Shuffle the answer choices; if accuracy drops significantly, the model may have memorized the answer at a specific position (suggesting contamination rather than genuine understanding).
- Canary test: Insert unique "canary" sentences into the test set; after training, check whether the model can reproduce them perfectly.
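The n-gram overlap check is the simplest of these to prototype. The sketch below uses whitespace tokenization and lowercasing as stand-ins for a real tokenizer and normalizer; the function names and the in-memory set index are illustrative assumptions (a production system would typically need a disk-backed or probabilistic index).

```python
def ngrams(tokens, n):
    """Set of all contiguous n-token windows."""
    return {tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)}

def overlap_rate(test_text, train_ngrams, n=8):
    """Fraction of the test sample's n-grams that also occur in the training index."""
    grams = ngrams(test_text.lower().split(), n)
    if not grams:
        return 0.0  # sample shorter than n tokens
    return len(grams & train_ngrams) / len(grams)

train = "the quick brown fox jumps over the lazy dog"
index = ngrams(train.lower().split(), 3)
print(overlap_rate("The quick brown fox sleeps", index, n=3))
```

A sample would then be flagged when its overlap rate exceeds a chosen threshold; the threshold itself has to be calibrated per benchmark.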
Follow-up: Why might n-gram overlap detection produce false positives?
Because some common knowledge (e.g., "the sun rises in the east") will appear in both training and test sets; n-gram overlap does not mean genuine "memorization." One needs to distinguish "factual public knowledge" from "verbatim copying of specific test samples."
Q-RLHF-B (L3): Design an RLHF training system supporting a 70B actor. Describe the 4-model memory decomposition, rollout/train topology, and how you would choose between LoRA-in-RL vs full parameter updates.
Answer:
Step 1: Clarify
- 70B actor (~140 GB BF16 parameters) + critic (similar or slightly smaller) + ref model + RM
- Naive co-located memory for all 4 models: parameters + optimizer states on the order of 1 TB (not feasible; separation required)
- Goal: run on 8–64 × 80 GB A100/H100 GPUs with throughput that meets a reasonable training schedule
Step 2: 4-model memory decomposition (order-of-magnitude estimates)
| Model | Parameters (BF16) | Gradients | Optimizer (FP32 AdamW) | Deployment strategy |
|---|---|---|---|---|
| Actor (trainable) | ~140 GB | ~140 GB | ~560 GB | Train workers, ZeRO-3 sharding |
| Critic (trainable) | ~140 GB (smaller model possible) | ~140 GB | ~560 GB | Same, or separate ZeRO group |
| Ref model (frozen) | ~140 GB | None | None | Rollout workers, inference mode |
| Reward model (frozen) | few GB–~140 GB | None | None | Rollout workers |
- With full parameter training, the complete training state for actor + critic (parameters + gradients + optimizer) is ~1.5–2 TB scale; ZeRO-3 sharding across train workers requires tens of 80 GB GPUs (exact count depends on whether FP32 master copy, activation, and framework overhead are included).
- With LoRA-in-RL (rank=16–32), trainable parameters drop to well under 1% of total, and optimizer states fall from ~560 GB to a few GB, greatly reducing train worker memory requirements.
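The estimates in this step follow from parameter count × bytes per parameter, which can be reproduced in a few lines. This sketch assumes BF16 weights and gradients and 8 bytes/param for the two AdamW moments (an FP32 master copy would add 4 more); the function name and the 0.5% LoRA fraction are illustrative assumptions, and adapter weights themselves are ignored.

```python
def training_state_gb(num_params, trainable_fraction=1.0,
                      weight_bytes=2, grad_bytes=2, optim_bytes=8):
    """Weights + gradients + AdamW moments, in GB (1e9 bytes)."""
    weights = num_params * weight_bytes
    trainable = num_params * trainable_fraction
    # Only trainable parameters need gradients and optimizer states.
    return (weights + trainable * (grad_bytes + optim_bytes)) / 1e9

full = training_state_gb(70e9)                            # full-parameter actor
lora = training_state_gb(70e9, trainable_fraction=0.005)  # assumed 0.5% trainable
print(full, lora)
```

Under these assumptions a fully trained 70B actor needs about 840 GB of model state before activations, while the LoRA variant stays close to the 140 GB of frozen weights.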
Step 3: Topology design
Rollout cluster (inference-optimized) Train cluster (training-optimized)
┌──────────────────────────┐ ┌─────────────────────────┐
│ vLLM / SGLang │ │ ZeRO-3 / FSDP │
│ - actor (FP16 weights) │◄─weights│ - actor (trainable) │
│ - ref model (frozen) │ sync │ - critic (trainable) │
│ - RM (frozen) │ │ │
│ │──data──►│ rollout buffer │
│ continuous rollout, │ │ PPO / GRPO updates │
│ output │ │ │
│ (prompt, resp, reward, │ │ │
│ log_prob, value) │ │ │
└──────────────────────────┘ └─────────────────────────┘
- Rollout and train run concurrently (async) or alternately (sync); weights synced once per iteration.
- Ref model and RM require only inference; placing them on the rollout side saves train-side memory.
Step 4: LoRA-in-RL vs full parameter updates
| Consideration | Favors LoRA-in-RL | Favors full updates |
|---|---|---|
| Memory budget | Tight (fewer GPUs) | Abundant (many GPUs) |
| Required policy change magnitude | Small (conversational style alignment) | Large (complex reasoning improvement) |
| Training stability | More stable (low-rank constraint) | Needs more careful tuning of the KL coefficient $\beta$ and the clip range |
| Reference | OpenRLHF LoRA mode | veRL / Megatron-LM full parameters |
⚠️ The memory figures above are order-of-magnitude estimates (derived from parameter count × bytes/param formulas). Actual values differ substantially due to activations, KV cache, and framework overhead. In interviews, explicitly state "estimate."
Follow-up: How do you decide the resource ratio of rollout to train workers in a disaggregated architecture?
It depends on the ratio of rollout throughput to train throughput. If rollout is the bottleneck (long responses, large batches), add more rollout workers. If train is the bottleneck (large critic, many PPO mini-batches), add more train workers. In practice, first profile the GPU-hours per iteration for each side, allocate proportionally, then adjust based on observed queue utilization.
Q25: Comprehensive design question: Design a complete LLM system for an AI customer service application with 10 million daily active users, from data to deployment.
Answer (high-level overview):
1. Clarify:
- 10M DAU → estimated QPS of ~100–1000 (1–3 conversation turns per user per day averages roughly 115–350 requests/s; peak traffic can be several times higher)
- Latency SLA: P95 < 2s (time to first token), P99 < 5s
- Domain adaptation needed (customer service phrasing, product knowledge)
2. Data:
- Historical customer service conversation logs → clean and anonymize → build SFT data
- Periodically sample online bad cases (low ratings, escalated to human agents) → human annotation → feed back into training
- RAG: build a vector knowledge base from product documentation and FAQs
3. Model:
- Base model: 7B–13B scale (balance quality and inference cost)
- SFT (LoRA) fine-tuned on customer service data
- RAG retrieval augmentation: user query → retrieve relevant documents → append to prompt context
4. Serving:
- Quantization: INT8 or INT4 (GPTQ/AWQ) → reduce per-GPU inference cost
- vLLM / TensorRT-LLM deployment, continuous batching + PagedAttention
- Multiple replicas + load balancing, auto-scaling with traffic
5. Monitoring:
- Online metrics: escalation rate, user satisfaction score, average conversation turns
- Quality drift: regularly run eval on a standard test set and monitor score changes
- Safety: apply sensitive word and harmful content filtering to outputs
Follow-up: In this system, what problems do RAG and fine-tuning each solve? Can they replace each other?
Fine-tuning handles "style and format" — making the model respond in a customer-service tone and follow the correct workflow. RAG handles "knowledge and facts" — providing up-to-date product information and company policy. They are complementary, not interchangeable: fine-tuning alone causes hallucinations about product details; RAG alone makes the model sound like a generic assistant rather than a professional customer service agent. The ideal solution combines both.
Appendix: Key Term Glossary
| Chinese | English | Abbreviation |
|---|---|---|
| 因果语言模型 | Causal Language Model | CLM |
| 低秩适配 | Low-Rank Adaptation | LoRA |
| 参数高效微调 | Parameter-Efficient Fine-Tuning | PEFT |
| 人类反馈强化学习 | Reinforcement Learning from Human Feedback | RLHF |
| 直接偏好优化 | Direct Preference Optimization | DPO |
| 奖励模型 | Reward Model | RM |
| 数据并行 | Data Parallelism | DP |
| 张量并行 | Tensor Parallelism | TP |
| 流水线并行 | Pipeline Parallelism | PP |
| 序列并行 | Sequence Parallelism | SP |
| 零冗余优化器 | Zero Redundancy Optimizer | ZeRO |
| 完全分片数据并行 | Fully Sharded Data Parallel | FSDP |
| 键值缓存 | Key-Value Cache | KV Cache |
| 训练后量化 | Post-Training Quantization | PTQ |
| 基于激活感知的权重量化 | Activation-Aware Weight Quantization | AWQ |
| 投机解码 | Speculative Decoding | — |
| 分页注意力 | PagedAttention | — |
| 检索增强生成 | Retrieval-Augmented Generation | RAG |
| 指令微调 | Instruction Tuning / SFT | SFT |
| 灾难性遗忘 | Catastrophic Forgetting | — |
| 知识蒸馏 | Knowledge Distillation | KD |
| 领域自适应预训练 | Domain-Adaptive Pretraining | DAP |
Extended L3
Q26: Explain the IO-aware tiling strategy in FlashAttention. Why does standard attention have a memory access bottleneck, and how does the online softmax trick enable block-wise computation without materializing the full N×N attention matrix?
Standard attention must write the complete attention matrix to HBM (High Bandwidth Memory), making IO the bottleneck. FlashAttention uses GPU SRAM (fast but small) via tiling:
- Split $Q$, $K$, $V$ into blocks of size $B_r$ (query rows) and $B_c$ (key/value rows); load only one block into SRAM at a time
- For each Q block, iterate over all K/V blocks and compute local attention within SRAM
- Use online softmax by maintaining a running max $m$ and running sum $\ell$: after processing the $j$-th KV block, update the previously accumulated output with a correction factor $e^{m^{(j-1)} - m^{(j)}}$, avoiding the need for global normalization
IO complexity drops from $O(Nd + N^2)$ HBM accesses to $O(N^2 d^2 / M)$ ($M$ = SRAM size); memory drops from $O(N^2)$ to $O(N)$ (the full attention matrix is never materialized).
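The online softmax update can be checked against a direct softmax on a single query row. This is a pure-Python sketch of the rescaling logic only, with scalar values standing in for value vectors; it says nothing about the actual GPU kernel, and the function names are hypothetical.

```python
import math

def attention_row_reference(scores, values):
    """softmax(scores) @ values for one query row (values are scalars here)."""
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)

def attention_row_online(scores, values, block_size):
    """Same result, processing keys/values one block at a time."""
    m = float("-inf")  # running max
    l = 0.0            # running sum of exp(score - m)
    acc = 0.0          # running unnormalized output
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s_blk))
        correction = math.exp(m - m_new)  # rescales earlier partial results
        p = [math.exp(s - m_new) for s in s_blk]
        l = l * correction + sum(p)
        acc = acc * correction + sum(pi * vi for pi, vi in zip(p, v_blk))
        m = m_new
    return acc / l  # single normalization at the end
```

Only `m`, `l`, and `acc` are carried between blocks, which is why the full row of attention scores never has to be stored.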
Follow-up: FlashAttention's backward pass requires recomputing the attention matrix (recomputation) — what are the similarities and differences with gradient checkpointing? In very-long-sequence settings, what additional parallelization optimizations does FlashAttention v2 introduce?
Q27: How does RoPE's NTK-aware interpolation address the long-sequence extrapolation problem? Why does simple position interpolation lose high-frequency information?
Simple position interpolation (PI) uniformly scales position $m$ to $m/s$ ($s$ = context extension factor). The problem is that RoPE frequencies span multiple orders of magnitude:
- Low dimensions (small $i$) → high frequency, encoding fine-grained positional relationships at short distances
- High dimensions (large $i$) → low frequency, encoding coarse positional relationships at long distances
After uniform scaling, the rotation angle in high-frequency dimensions changes too densely; the model cannot distinguish adjacent tokens (high-frequency information is "compressed together"), analogous to applying a low-pass filter to an image and losing edge details.
NTK-aware interpolation rescales the base frequency from $b$ to $b' = b \cdot s^{d/(d-2)}$ ($s$ = context extension factor, $d$ = head dimension):
- Low-dimensional high-frequency components are nearly unchanged → preserving local resolution
- High-dimensional low-frequency components are stretched → encoding longer distances
This is analogous to the NTK theory's observation about the difference in learning difficulty between high-frequency and low-frequency features: high-frequency features require higher resolution; low-frequency features can be safely extrapolated.
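The effect of changing the base can be inspected directly on the per-dimension frequencies. The sketch below assumes the commonly used NTK-aware rule `b' = b * s^(d / (d - 2))` with RoPE frequencies `theta_i = b^(-2i/d)`; the function names and the example values (`d = 128`, `s = 4`) are illustrative.

```python
def rope_inv_freqs(head_dim, base=10000.0):
    """theta_i = base^(-2i/d) for i = 0 .. d/2 - 1."""
    return [base ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]

def ntk_scaled_base(base, scale, head_dim):
    """NTK-aware base (assumed rule): b' = b * s^(d / (d - 2))."""
    return base * scale ** (head_dim / (head_dim - 2))

d, s = 128, 4.0
original = rope_inv_freqs(d)
ntk = rope_inv_freqs(d, ntk_scaled_base(10000.0, s, d))
pi = [f / s for f in original]  # position interpolation scales every frequency

print("highest freq:", original[0], ntk[0], pi[0])
print("lowest freq: ", original[-1], ntk[-1], pi[-1])
```

Under this rule the highest-frequency dimension is left untouched and the lowest-frequency dimension is slowed by exactly the factor `s`, whereas position interpolation slows every dimension by `s`.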
Follow-up: YaRN further applies temperature scaling to the attention score on top of NTK-aware. What is the motivation? Why is modifying position encoding alone insufficient to fully recover long-context task performance?
Q28: In a Mixture of Experts (MoE) architecture, how do you design an auxiliary load balancing loss to prevent expert collapse? What is the role of the capacity factor?
Expert collapse in MoE: a small number of experts are selected frequently while the rest are nearly idle, wasting the model's effective capacity.
Auxiliary load balancing loss:
- = number of experts, = fraction of tokens routed to expert (discrete statistics), = average probability the router assigns to expert (continuous, differentiable)
- The term encourages both to be uniformly distributed: when an expert is both frequently selected and has high router confidence, the penalty is largest
- is set to a small value to prevent it from dominating the main training loss
Capacity factor (CF): Limits the maximum number of tokens each expert can process in one batch to $\text{CF} \times T / N$, where $T$ is the number of tokens in the batch and $N$ the number of experts. CF too small → tokens are dropped (overflow) → information loss; CF too large → computational waste (padding). CF should be tuned according to the degree of load imbalance.
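The loss and the capacity rule can be sketched in a few lines. This is a minimal illustration assuming top-1 routing and a Switch-Transformer-style loss $\alpha \cdot N \sum_i f_i P_i$; the function names, the default `alpha = 0.01`, the rounding with `ceil`, and the toy router outputs are assumptions made for the example.

```python
import math

def load_balance_loss(router_probs, num_experts, alpha=0.01):
    """router_probs: one softmax distribution over experts per token (top-1 routing assumed)."""
    num_tokens = len(router_probs)
    counts = [0] * num_experts
    prob_sums = [0.0] * num_experts
    for p in router_probs:
        chosen = max(range(num_experts), key=lambda i: p[i])  # argmax routing
        counts[chosen] += 1
        for i in range(num_experts):
            prob_sums[i] += p[i]
    f = [c / num_tokens for c in counts]      # fraction of tokens per expert (not differentiable)
    P = [s / num_tokens for s in prob_sums]   # mean router probability per expert (differentiable)
    return alpha * num_experts * sum(fi * Pi for fi, Pi in zip(f, P))

def expert_capacity(num_tokens, num_experts, capacity_factor, top_k=1):
    """Max tokens per expert; rounding up is an assumption, implementations differ."""
    return math.ceil(capacity_factor * num_tokens * top_k / num_experts)

# Toy router outputs for 4 tokens and 4 experts (hypothetical numbers).
balanced = [
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.1, 0.7, 0.1],
    [0.1, 0.1, 0.1, 0.7],
]
collapsed = [[0.7, 0.1, 0.1, 0.1]] * 4  # every token prefers expert 0

loss_balanced = load_balance_loss(balanced, 4)
loss_collapsed = load_balance_loss(collapsed, 4)
capacity = expert_capacity(num_tokens=1024, num_experts=8, capacity_factor=1.25)
print(loss_balanced, loss_collapsed, capacity)
```

With a perfectly uniform assignment the sum reduces to $1/N$, so the loss equals $\alpha$; the collapsed router is penalized more heavily because the same expert has both a high token share and a high average probability.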
Follow-up: DeepSeek-MoE proposes fine-grained expert segmentation (splitting large experts into multiple smaller ones) and a shared expert mechanism. How does this design fundamentally mitigate the tension between load balancing (requiring uniformity) and model capability (requiring specialization)?
Q29: How does ZeRO-3's All-Gather communication overlap with forward/backward computation? Why does a naive implementation lead to a significant communication bottleneck?
ZeRO-3 requires All-Gather of the complete parameters for each layer before the forward pass can proceed. Naive implementation: All-Gather → wait → compute → free; communication and computation are serial, and the GPU spends a long time waiting.
Overlap strategy (dependency graph analysis using forward as an example):
Forward: compute(L) ← All-Gather(L) compute(L+1) ← All-Gather(L+1)
↓ can overlap: while compute(L) runs, asynchronously prefetch All-Gather(L+1)
- Forward: While computing layer $L$, asynchronously launch the All-Gather for layer $L+1$ parameters (prefetch). Requirement: compute time for layer $L$ ≥ communication time for layer $L+1$.
- Backward: Similarly, while computing layer $L$ gradients, prefetch layer $L-1$ parameters; the Reduce-Scatter of layer $L$ gradients can also be overlapped with the next layer's computation.
Cost: Prefetching keeps more parameter copies resident at the same time (current layer + prefetched layer), adding memory pressure. Total communication is ~$3\Psi$ per step (higher than DP's $2\Psi$, where $\Psi$ is the parameter count); when cross-node bandwidth is limited this can become a bottleneck.
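The benefit of prefetching can be illustrated with a toy timeline model of the forward pass. This is a sketch under simplifying assumptions: one communication stream, one compute stream, a prefetch depth of one layer, and made-up per-layer times; `serial_time` and `prefetch_time` are hypothetical helper names.

```python
def serial_time(compute, gather):
    """Naive schedule: All-Gather, wait, compute, for each layer in turn."""
    return sum(compute) + sum(gather)

def prefetch_time(compute, gather):
    """One-layer-ahead prefetch: All-Gather(l+1) is issued when compute(l) starts."""
    gather_done = gather[0]   # the first All-Gather cannot be hidden
    compute_done = 0.0
    for layer in range(len(compute)):
        # A layer needs its parameters and the previous layer's output.
        start = max(gather_done, compute_done)
        compute_done = start + compute[layer]
        if layer + 1 < len(compute):
            # Prefetch runs on the communication stream, in parallel with compute.
            gather_done = start + gather[layer + 1]
    return compute_done

# Compute-bound case: each layer's compute (4) covers the next All-Gather (3).
t_serial = serial_time([4, 4, 4, 4], [3, 3, 3, 3])
t_overlap = prefetch_time([4, 4, 4, 4], [3, 3, 3, 3])

# Communication-bound case: compute (2) is shorter than the All-Gather (3).
t_serial_cb = serial_time([2, 2, 2, 2], [3, 3, 3, 3])
t_overlap_cb = prefetch_time([2, 2, 2, 2], [3, 3, 3, 3])
print(t_serial, t_overlap, t_serial_cb, t_overlap_cb)
```

In the compute-bound case only the first All-Gather remains exposed. In the communication-bound case the total is dominated by the sum of All-Gather times, which is why the requirement that compute time covers the next layer's communication matters.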
Follow-up: At what model scale and hardware conditions does ZeRO-3's communication overhead become unacceptable, making TP (intra-node NVLink) + ZeRO-2 the better choice? Analyze from the perspective of the ratio of communication volume to computation volume.
Q30: DPO's training data is off-policy (generated by $\pi_{\text{ref}}$). What theoretical bias does this introduce? How does iterative DPO mitigate this?
The $\log \frac{\pi_\theta(y \mid x)}{\pi_{\text{ref}}(y \mid x)}$ term in the DPO loss is essentially an importance-weighted reward estimate.
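To make the role of the log-ratio term concrete, here is a minimal per-pair sketch of the standard DPO objective, $-\log \sigma\big(\beta[(\log\pi_\theta(y_w) - \log\pi_{\text{ref}}(y_w)) - (\log\pi_\theta(y_l) - \log\pi_{\text{ref}}(y_l))]\big)$. The function name, the default `beta = 0.1`, and the sequence log-probabilities are illustrative assumptions.

```python
import math

def dpo_loss(logp_w, logp_l, ref_logp_w, ref_logp_l, beta=0.1):
    """Per-pair DPO loss from sequence log-probs under the policy and the reference.

    w = chosen response, l = rejected response.
    """
    # Implicit rewards: beta times the log ratio of policy to reference.
    reward_w = beta * (logp_w - ref_logp_w)
    reward_l = beta * (logp_l - ref_logp_l)
    margin = reward_w - reward_l
    # -log(sigmoid(margin)), written as a numerically stable softplus(-margin).
    loss = max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))
    return loss, reward_w, reward_l

# Policy identical to the reference: both implicit rewards are 0, loss = log 2.
loss_init, _, _ = dpo_loss(-10.0, -12.0, -10.0, -12.0)

# Policy has raised the chosen response's log-prob relative to the reference.
loss_later, rw, rl = dpo_loss(-8.0, -12.0, -10.0, -12.0)
print(loss_init, loss_later, rw, rl)
```

Note that both log ratios are only ever evaluated on responses present in the fixed preference dataset, which is the sense in which the signal is tied to the data-collection policy.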
Source of off-policy bias:
- As the divergence between $\pi_\theta$ and $\pi_{\text{ref}}$ grows, the variance of importance weights increases and gradient estimates become unstable
- The training data covers a fixed $(x, y)$-space anchored to $\pi_{\text{ref}}$'s support. $\pi_\theta$ may have learned to generate responses not seen in training data, but those responses cannot be evaluated by the DPO loss → optimization signal has blind spots
- Analogous to distribution shift in off-policy RL: the further the policy departs from the data-collection policy, the less reliable the estimates
How iterative DPO mitigates this:
- Sample new responses with the current $\pi_\theta$
- Annotate preferences with a reward model or human annotators
- Use the new $\pi_\theta$ as the new $\pi_{\text{ref}}$; retrain DPO
- Repeat → training data gradually becomes on-policy
Online DPO goes further: within the training loop, it samples $\pi_\theta$'s outputs in real time, scores them with the RM, and immediately updates.
Follow-up: In online DPO, if the reward model itself has a systematic bias (e.g., preferring verbose answers), how would online iteration amplify that problem? What are the mechanistic similarities and differences with reward hacking in PPO?
Q31: How can reward model over-optimization in RLHF be explained theoretically? How does the divergence between proxy reward and true quality change as KL increases?
This is a manifestation of Goodhart's Law: when a proxy metric is optimized to the extreme, it decouples from the true objective.
Theoretical intuition:
- Let the true reward be $r^*$, the proxy RM $\hat{r}$, and their difference $\epsilon = \hat{r} - r^*$
- When $\pi_\theta$ optimizes in the direction of $\hat{r}$, it not only improves $r^*$ but also exploits $\epsilon$, entering regions where $\hat{r}$ is overestimated
- As $\mathrm{KL}(\pi_\theta \,\|\, \pi_{\text{ref}})$ increases, the policy departs further from the RM's training distribution, and the generalization error of $\hat{r}$ (i.e., $|\epsilon|$) tends to grow
- Qualitative observation: the proxy reward keeps rising; true quality first rises then falls; the point where true quality peaks and the two curves start to diverge is the "over-optimization inflection point"
Factors influencing the divergence rate:
- Larger RM capacity and more diverse preference data → the inflection point appears later
- Larger policy exploration space (longer, more diverse generation) → easier to find reward-hacking paths
Mitigation strategies: KL penalty, RM ensemble (taking the min or variance penalty across multiple RMs), periodic RM updates.
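The ensemble-based mitigations can be sketched as simple aggregation rules over per-RM scores. This is an illustration only: the function name, the mode strings, the penalty weight, and the example scores are hypothetical, and using the population variance as the disagreement measure is one possible choice among several.

```python
import statistics

def conservative_reward(scores, mode="min", penalty=1.0):
    """Aggregate scores from several reward models into one conservative reward."""
    if mode == "min":
        # Worst-case aggregation: trust the most pessimistic RM.
        return min(scores)
    if mode == "variance_penalty":
        # Mean reward, discounted where the RMs disagree.
        return statistics.fmean(scores) - penalty * statistics.pvariance(scores)
    raise ValueError(f"unknown mode: {mode}")

agree = [1.0, 1.0, 1.0]      # RMs agree: likely in-distribution
disagree = [2.0, 1.0, 0.0]   # same mean, but high disagreement

r_agree = conservative_reward(agree, "variance_penalty")
r_disagree = conservative_reward(disagree, "variance_penalty")
r_min = conservative_reward(disagree, "min")
print(r_agree, r_disagree, r_min)
```

Both responses have the same mean score, but the one on which the RMs disagree receives a lower aggregated reward, which discourages the policy from moving into regions where a single RM is likely overestimating.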
Follow-up: In practice, how does a reward model ensemble exploit agreement and disagreement among multiple RMs? What are the pros and cons of taking the min, the mean, or using disagreement as an uncertainty signal? How does computational cost affect feasibility?
Q32: How does Multi-head Latent Attention (MLA) reduce KV cache memory through low-rank compression? What is the fundamental difference from GQA in terms of compression mechanism?
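MLA no longer stores the complete $K, V$; instead it stores a low-dimensional latent vector $c_t^{KV}$, which is decompressed at inference time:
$$c_t^{KV} = W^{DKV} h_t$$
The KV cache stores only $c_t^{KV}$ (dimension $d_c$); at attention computation time it projects back:
$$k_t = W^{UK} c_t^{KV}, \quad v_t = W^{UV} c_t^{KV}$$
KV cache size drops from $2 n_h d_h$ to $d_c$ per token per layer ($d_c$ can be much smaller than $n_h d_h$).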
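A back-of-the-envelope comparison of per-token cache size makes the saving concrete. The configuration below (32 layers, 32 heads of dimension 128, 8 KV heads for GQA, latent dimension 512, 2 bytes per element) is a hypothetical example chosen for easy arithmetic, not the configuration of any specific model; the optional `rope_dim` argument is an assumption covering designs that additionally cache a small positional key.

```python
def mha_elems(n_heads, head_dim):
    """Full K and V for every head."""
    return 2 * n_heads * head_dim

def gqa_elems(n_kv_heads, head_dim):
    """K and V for the reduced set of KV heads."""
    return 2 * n_kv_heads * head_dim

def mla_elems(latent_dim, rope_dim=0):
    """One shared latent vector, plus an optional cached positional key (assumption)."""
    return latent_dim + rope_dim

def cache_bytes_per_token(n_layers, elems_per_layer, bytes_per_elem=2):
    return n_layers * elems_per_layer * bytes_per_elem

n_layers, n_heads, head_dim = 32, 32, 128   # hypothetical model shape
mha = cache_bytes_per_token(n_layers, mha_elems(n_heads, head_dim))
gqa = cache_bytes_per_token(n_layers, gqa_elems(8, head_dim))
mla = cache_bytes_per_token(n_layers, mla_elems(512))
print(mha, gqa, mla)   # bytes per cached token
```

For these example numbers the cache shrinks 4× with GQA and 16× with MLA relative to MHA, while the number of query heads is unchanged in both cases.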
Fundamental difference from GQA:
| Dimension | GQA | MLA |
|---|---|---|
| Compression target | Head axis (reduce number of KV heads) | Feature axis (low-rank projection) |
| Compression nature | Discrete, structured (head grouping) | Continuous, flexible (learnable subspace) |
| Cache contents | Actual K, V values (just fewer heads) | Compressed latent vector (requires decompression) |
| Diversity preservation | Directly preserves independent heads | Relies on expressiveness of low-rank subspace |
MLA's advantage: the number of Q heads is no longer directly tied to cache size, enabling large cache compression while retaining many Q heads. Trade-off: inference requires extra projection computation, and the low-rank constraint may limit pattern diversity across heads.
Follow-up: Does MLA's low-rank compression cause different attention heads' patterns to converge (loss of head diversity)? Can the high rank of the projection matrix fully mitigate this risk? In practice, what signals can detect degradation of head diversity?
§A Key Papers Timeline
2018-11 · GPipe — Huang et al., NeurIPS 2019. arXiv:1811.06965 — Foundational pipeline parallelism: partitions layers into stages across devices and splits each mini-batch into micro-batches fed through the pipeline to amortize the bubble, trading recomputation for activation memory so giant models fit across devices.
2019-09 · Megatron-LM — Shoeybi et al., arXiv preprint. arXiv:1909.08053 — Intra-layer tensor parallelism: shards attention and MLP weight matrices column/row-wise across GPUs with one all-reduce each in the forward ($g$) and backward ($f$) pass, scaling to billions of parameters with no change to model structure.
2019-10 · ZeRO — Rajbhandari et al., SC 2020. arXiv:1910.02054 — Shards the redundant optimizer states / gradients / parameters of data parallelism across ranks (Stages 1/2/3), cutting per-GPU memory from $16\Psi$ to roughly $16\Psi / N_d$ ($\Psi$ = parameter count, $N_d$ = data-parallel degree, mixed-precision Adam) without incurring tensor-parallel communication cost.
2022-05 · Reducing Activation Recomputation — Korthikanti et al., MLSys 2023. arXiv:2205.05198 — Sequence parallelism + selective recomputation: shards activations of element-wise ops (LayerNorm/Dropout) along the sequence dimension and recomputes only the cheapest-to-redo ops, cutting activation memory ~5×, orthogonal to tensor parallelism.
2022-05 · FlashAttention — Dao et al., NeurIPS 2022. arXiv:2205.14135 — IO-aware exact attention: uses tiling + online softmax to keep tiles of the intermediate $N \times N$ attention matrix in SRAM instead of materializing it in HBM, freeing attention from the memory-bandwidth bottleneck and making memory linear (not quadratic) in sequence length.
2022-09 · FP8 Formats for Deep Learning — Micikevicius et al., arXiv preprint. arXiv:2209.05433 — Defines two 8-bit floating-point encodings for deep learning: E4M3 (range ±448, precision-first, forward pass) and E5M2 (range ±57344, dynamic-range-first, gradients), setting the standard for H100-era FP8 training/inference.
2022-10 · GPTQ — Frantar et al., ICLR 2023. arXiv:2210.17323 — One-shot post-training weight quantization via the OBS (Optimal Brain Surgeon) second-order approximation: quantizes column-by-column and compensates the remaining weights using the inverse Hessian, compressing 175B models to 3–4 bit with little accuracy loss.
2022-11 · SmoothQuant — Xiao et al., ICML 2023. arXiv:2211.10438 — W8A8 quantization: activations have hard-to-quantize outlier channels, so it per-channel "migrates" the difficulty from activations to weights ($\hat{X} = X\,\mathrm{diag}(s)^{-1}$, $\hat{W} = \mathrm{diag}(s)\,W$), letting both use INT8 without mixed precision.
2022-11 · Speculative Decoding — Leviathan et al., ICML 2023. arXiv:2211.17192 — A small draft model proposes several tokens that the large target model verifies in parallel, with a carefully designed accept-reject sampling rule guaranteeing the output distribution exactly matches target-only decoding (lossless speedup).
2023-06 · AWQ — Lin et al., MLSys 2024. arXiv:2306.00978 — Activation-aware weight quantization: observes that a tiny fraction of "salient" weight channels dominate error and uses activation magnitude to guide per-channel scaling that protects them, preserving accuracy at 4-bit in a hardware-friendly way.
2023-09 · PagedAttention / vLLM — Kwon et al., SOSP 2023. arXiv:2309.06180 — Manages the KV cache like OS virtual-memory paging: stores KV in non-contiguous blocks allocated on demand, eliminating fragmentation and reservation waste and enabling prefix sharing, greatly raising serving throughput.
2024-02 · KIVI — Liu et al., ICML 2024. arXiv:2402.02750 — Asymmetric 2-bit quantization for the KV cache: quantizes keys per-channel and values per-token (matching their distinct outlier distributions), cutting peak memory (incl. model weights) ~2.6× in long-context inference (KV itself 16-bit→2-bit, ~8× in theory) with near-lossless accuracy.