1. Why Mamba, and why now?
Long-context sequence modeling has an uncomfortable tradeoff:
- We want strong content-dependent reasoning over long windows.
- We want training and inference to remain cheap enough to run on a single, real GPU budget.
Transformers are excellent at content-dependent routing, but full attention scales quadratically with sequence length. In practice, that means memory and latency can become the limiting factor long before model quality does.
Mamba tackles this by combining a recurrent state-space backbone with per-token selectivity. The result is a model family that preserves causal, linear-time sequence processing while still reacting strongly to token content [1].
For this tutorial, treat Mamba as three layers at once:
- A dynamical system with learnable memory.
- A neural module that generates token-dependent update rules.
- A systems implementation where kernel design and memory traffic matter.
By the end, you should be able to move between all three layers comfortably.
Figure: attention-style scaling grows much faster than recurrent SSM scaling as sequence length increases, which is exactly the regime Mamba is designed to target.
Bridge: Why not just use an RNN then?
Classic RNNs are linear-time, but their hidden-state dynamics are often hard to optimize over long horizons. In practice, gradients vanish or explode, or the architecture needs heavy gating machinery that still struggles with very long dependencies.
Modern SSM-based models improve this by giving the recurrence a more principled dynamical structure. You can think of it as replacing “generic recurrent matrix” with a better-behaved family of update rules. S4 made that practical for deep sequence modeling, and Mamba adds token-dependent selectivity so the memory policy can change from one token to the next.
So Mamba is not “just another RNN.” It is closer to “a structured recurrent dynamical model with learned, content-aware control inputs.”
2. Notation and shape conventions
This tutorial is shape-heavy. If you keep the dimensions straight, almost every implementation detail becomes easier.
We use the following symbols throughout:
- (B): batch size
- (L): sequence length
- (d_{inner}): inner channel dimension (SSM channels)
- (N): state dimension per channel
- (d_{model}): model embedding width
Main tensors:
- input sequence: (u \in \mathbb{R}^{B \times L \times d_{inner}})
- per-token step sizes: (\Delta \in \mathbb{R}^{B \times L \times d_{inner}})
- continuous dynamics: (A \in \mathbb{R}^{d_{inner} \times N})
- selective input map: (B_t \in \mathbb{R}^{B \times d_{inner} \times N})
- selective readout map: (C_t \in \mathbb{R}^{B \times d_{inner} \times N})
- recurrent state: (x_t \in \mathbb{R}^{B \times d_{inner} \times N})
- output: (y \in \mathbb{R}^{B \times L \times d_{inner}})
In the block implementation, one easy way to avoid confusion is to read dimensions in this order:
- Token-facing tensors are usually `(B, L, *)`.
- Recurrent state tensors are usually `(B, d_inner, N)`.
- Parameter tensors tied to dynamics are usually `(*, d_inner, N)`.
We will often highlight critical objects with color:
\[\textcolor{#4ecdc4}{A},\;\textcolor{#ff9f6e}{\Delta_t},\;\textcolor{#7ee787}{x_t}\]

Optional: Mapping to code variable names
Use this mapping when reading code:
- `u`: scan input stream, shape `(B, L, d_inner)`
- `delta`: per-token step sizes, shape `(B, L, d_inner)`
- `A`: channel/state dynamics matrix, shape `(d_inner, d_state)`
- `B`: input-to-state coefficients, shape `(B, L, d_inner, d_state)`
- `C`: state-to-output coefficients, shape `(B, L, d_inner, d_state)`
- `D`: skip vector, shape `(d_inner,)`
- `x0`: optional initial recurrent state, shape `(B, d_inner, d_state)`
Two practical tips:
- `B` and `C` here are not the same objects as continuous-time `B_c` and `C_c` from textbook control notation.
- In this repository, `d_state` in code corresponds to `N` in the math.
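To keep these conventions honest while reading code, a tiny shape-bookkeeping script can help. This is plain Python with hypothetical dimension values, not repo code:

```python
# Hypothetical dimension values; only the shape structure matters.
B, L, d_inner, d_state = 2, 16, 64, 8

shapes = {
    "u": (B, L, d_inner),       # scan input stream
    "delta": (B, L, d_inner),   # per-token step sizes
    "A": (d_inner, d_state),    # channel/state dynamics
    "B": (B, L, d_inner, d_state),
    "C": (B, L, d_inner, d_state),
    "D": (d_inner,),            # skip vector
    "x0": (B, d_inner, d_state),
}

# Token-facing tensors lead with (B, L, ...).
for name in ("u", "delta", "B", "C"):
    assert shapes[name][:2] == (B, L)

# Tensors tied to dynamics end with (..., d_inner, d_state).
for name in ("A", "B", "C", "x0"):
    assert shapes[name][-2:] == (d_inner, d_state)
```

Checks like these are cheap to paste into a debugger session whenever a projection or state dimension changes.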
3. State space models from first principles
3.1 Continuous-time view
A linear state space model in continuous time is written as
\[\dot{x}(t) = \textcolor{#4ecdc4}{A}x(t) + B_c u(t), \qquad y(t) = C_c x(t) + d_{\mathrm{skip}} \odot u(t).\]

Intuition:
- (x(t)) is internal memory.
- (A) controls how memory decays or oscillates.
- (B_c) injects current input into memory.
- (C_c) reads memory into output.
- (d_{\mathrm{skip}}) is a direct skip path from input to output.
For sequence models, we do not observe a smooth timeline. We observe tokens, one discrete step at a time. So we need a discrete recurrence that is stable, expressive, and cheap to compute.
Mamba’s selective scan can be read as exactly that: a discretized, token-indexed state-space recurrence with token-dependent control terms.
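As a bridging step (this is the standard zero-order-hold discretization argument, stated here for intuition rather than quoted from the repo), solving the continuous system over one step of length \(\Delta\) gives

\[x(t+\Delta) = e^{\Delta A}\,x(t) + \left(\int_0^{\Delta} e^{sA}\,ds\right) B_c\,u(t) \;\approx\; e^{\Delta A}\,x(t) + \Delta\,B_c\,u(t).\]

The first term is the survival factor \(\bar{A} = e^{\Delta A}\); approximating the integral by its first-order term yields the \(\Delta\,B\,u\) injection term that appears in the discrete recurrence of the next subsection.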
3.2 Discrete-time recurrence used by Mamba-style scan
At token (t), with input-dependent step (\Delta_t):
\(\bar{A}_t[b,d,n] = \exp\!\left(\Delta_t[b,d]\;A[d,n]\right),\) \(x_t[b,d,n] = \bar{A}_t[b,d,n]\;x_{t-1}[b,d,n] + \Delta_t[b,d]\;B_t[b,d,n]\;u_t[b,d],\) \(y_t[b,d] = \sum_{n=1}^{N} C_t[b,d,n]\;x_t[b,d,n] + d_{\mathrm{skip}}[d]\;u_t[b,d].\)
Everything above is channelwise and statewise. In tensor form, the implementation treats these as broadcasted elementwise operations over ((B, d_{inner}, N)).
A useful way to read the update:
- `exp(delta_t * A)` decides how much old state survives.
- `delta_t * B_t * u_t` injects the new input into state.
- `C_t` reads the updated state into the output channel.
- `d_skip * u_t` keeps a direct short path that does not go through recurrent state.
That “survive + inject + read + skip” decomposition is the conceptual core of Mamba.
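The same decomposition can be written as a minimal scalar reference for a single (batch, channel, state) coordinate. This is illustrative plain Python, not the repository's API:

```python
import math

def selective_scan_scalar(u, delta, A, B, C, d_skip, x0=0.0):
    """Reference recurrence for one (batch, channel, state) coordinate."""
    x, ys = x0, []
    for t in range(len(u)):
        a_bar = math.exp(delta[t] * A)          # survive: old-state decay
        x = a_bar * x + delta[t] * B[t] * u[t]  # inject: new input
        ys.append(C[t] * x + d_skip * u[t])     # read + skip
    return ys

# With A = 0, delta = B = C = 1, and d_skip = 0, the scan reduces to a
# running sum of the input:
out = selective_scan_scalar([1.0, 2.0, 3.0], [1.0] * 3, 0.0,
                            [1.0] * 3, [1.0] * 3, 0.0)
# out == [1.0, 3.0, 6.0]
```

Sanity cases like this (A = 0 gives a cumulative sum; very negative A gives a near-memoryless map) are useful unit-test anchors for any scan implementation.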
Bridge: Why parameterize A as negative exponential?
In code, A is parameterized as -exp(A_log). That does two practical things:
- It makes the sign of each mode negative by construction, which pushes dynamics toward decay instead of unbounded growth.
- It lets optimization happen in log-space (`A_log`), which gives smoother parameter updates than trying to optimize raw constrained values.
This does not magically guarantee perfect behavior, but it strongly biases the recurrence toward numerically well-behaved regimes, especially early in training.
You can see this directly in MambaBlock:
A = -torch.exp(self.A_log)
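A quick numeric check of this parameterization, with hypothetical values in plain Python rather than the repo's tensors:

```python
import math

A_log = [-2.0, 0.0, 1.5]                  # unconstrained parameters
A = [-math.exp(a) for a in A_log]         # strictly negative by construction
decay = [math.exp(0.1 * a) for a in A]    # discrete survival factor, delta = 0.1

assert all(a < 0 for a in A)              # decay-only continuous modes
assert all(0.0 < f < 1.0 for f in decay)  # state shrinks, never blows up
```

Because `exp` maps any real `A_log` to a positive number, the discrete factor `exp(delta * A)` stays inside `(0, 1)` for any positive step size, which is exactly the stability bias described above.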
4. From S4 to Mamba
S4 made structured state-space models practical for deep long-sequence learning by carefully designing parameterization and computation [2]. Mamba keeps that backbone but changes how updates are controlled.
High-level contrast:
- S4-style dynamics are mostly fixed after training.
- Mamba makes important terms input-dependent at each token.
In Mamba, (\Delta_t, B_t, C_t) are generated from current features, so the recurrence can adapt token by token [1].
Why this matters in practice:
- The model can preserve rare but important tokens over long spans.
- The model can forget quickly when context shifts.
- The model can do this with linear-time recurrent updates rather than quadratic attention.
Mamba: keeps linear recurrent backbone, but adds selective gating in the recurrence itself.
Optional deep dive: Why this matters for language
Language modeling alternates between “hold this for a long time” and “drop this right away.”
Examples:
- If a speaker identity appears, the model may need to keep that signal for many future tokens.
- If punctuation or formatting noise appears, it may be safe to ignore quickly.
- If a local syntactic pattern appears, short-term mixing is often enough.
A fixed recurrence applies one global memory policy everywhere. Mamba’s selective update behaves more like dynamic memory control, where each token influences how aggressively memory is kept or overwritten.
5. The selective scan equations (core of Mamba)
This section is the computational heart of Mamba. Everything else in the block exists to produce the inputs to this recurrence.
At each token (t), you have:
- input (u_t \in \mathbb{R}^{B \times d_{inner}})
- adaptive step size (\Delta_t \in \mathbb{R}^{B \times d_{inner}})
- selective input map (B_t \in \mathbb{R}^{B \times d_{inner} \times N})
- selective readout map (C_t \in \mathbb{R}^{B \times d_{inner} \times N})
- channel/state dynamics (A \in \mathbb{R}^{d_{inner} \times N})
- skip vector (d_{\mathrm{skip}} \in \mathbb{R}^{d_{inner}})
The update is:
\[\bar{A}_t = \exp\big(\textcolor{#ff9f6e}{\Delta_t} \odot \textcolor{#4ecdc4}{A}\big) \in \mathbb{R}^{B\times d_{inner}\times N},\]
\[\textcolor{#7ee787}{x_t} = \bar{A}_t \odot x_{t-1} + (\Delta_t \odot B_t) \odot u_t,\]
\[y_t[b,d] = \sum_{n=1}^{N} C_t[b,d,n] \, x_t[b,d,n] + d_{\mathrm{skip}}[d] \, u_t[b,d].\]

Two things to notice:
- The recurrence is diagonal in state coordinates once (A) is fixed, which makes channel/state updates cheap.
- Selectivity enters through (\Delta_t), (B_t), and (C_t), all token-dependent.
In vectorized code, this becomes:
delta_t = delta[:, t, :].unsqueeze(-1) # (B, d_inner, 1)
u_t = u[:, t, :].unsqueeze(-1) # (B, d_inner, 1)
A_bar_t = torch.exp(delta_t * A_expanded) # (B, d_inner, d_state)
B_bar_t = delta_t * B[:, t, :, :] # (B, d_inner, d_state)
state = A_bar_t * state + B_bar_t * u_t # (B, d_inner, d_state)
y_t = torch.sum(C[:, t, :, :] * state, dim=-1) # (B, d_inner)
Code: mamba4080/ops/selective_scan.py
Implementation trick: Why the vectorized path still loops over time
It is easy to vectorize over batch, channel, and state. It is not easy to remove temporal dependence because each state_t depends on state_{t-1}.
There are theoretical ways to parallelize recurrences, including scan-style associative transforms, but those methods are more complex and usually introduce tradeoffs in memory traffic or numerical behavior.
For a correctness-first baseline, keeping a single explicit loop over sequence length is a very good compromise:
- The code is easy to audit.
- The math maps one-to-one to implementation.
- It provides a strong reference target for optimized kernels.
6. Mamba block end-to-end: intuition, equations, implementation
A single block in this repository follows a faithful Mamba-1 pattern with explicit, inspectable components.
The block has two jobs:
- Produce selective scan parameters from the input stream.
- Wrap the scan with normalization, local mixing, gating, projection, and residual structure.
6.1 Block flow
- Pre-normalize with RMSNorm.
- Project from (d_{model}) to (2d_{inner}), split into (x)-stream and (z)-stream.
- Run depthwise causal convolution on (x)-stream.
- Generate selective parameters ((\Delta, B, C)) from convolved features.
- Run selective scan.
- Gate scan output with (\text{SiLU}(z)).
- Project back to (d_{model}), add residual.
6.2 Equations with shapes
Let (h \in \mathbb{R}^{B \times L \times d_{model}}).
\(\tilde{h} = \mathrm{RMSNorm}(h)\) \([x, z] = W_{in}\tilde{h}, \quad x,z \in \mathbb{R}^{B\times L\times d_{inner}}\) \(x' = \mathrm{SiLU}(\mathrm{DWConv}_{causal}(x))\) \((\Delta, B, C) = f_{param}(x')\) \(y = \mathrm{SelectiveScan}(u=x', \Delta, A, B, C, d_{\mathrm{skip}})\) \(\hat{y} = y \odot \mathrm{SiLU}(z)\) \(\text{out} = W_{out}\hat{y} + h\)
Read this as two coupled paths:
- The `x` path carries recurrence content and produces selective parameters.
- The `z` path acts as an output gate to modulate what leaves the recurrent path.
6.3 Code snippets
xz = self.in_proj(x_norm) # (B, L, 2 * d_inner)
x_stream, z_stream = torch.split(
xz, [self.d_inner, self.d_inner], dim=-1
)
x_conv = self.conv(x_stream) # (B, L, d_inner)
x_conv = F.silu(x_conv)
delta, B, C = self._project_selective_params(x_conv)
A = -torch.exp(self.A_log) # (d_inner, d_state)
y = selective_scan(u=x_conv, delta=delta, A=A, B=B, C=C, D=self.D,
implementation=self.ssm_impl)
y = y * F.silu(z_stream)
out = self.out_proj(y) + residual
Code: mamba4080/modules/mamba_block.py
Figure: one Mamba block, showing the split x/z streams, depthwise causal conv on x, selective scan core, gating, and residual recombination.
Optional: Why put depthwise causal convolution before parameter generation
Without the convolution, selective parameters are generated from each token independently. That can work, but it misses short-range local context that is often important for language.
Depthwise causal convolution gives each channel a small local receptive field at low cost:
- “Depthwise” means each channel is convolved independently, so parameter count stays modest.
- “Causal” means no future leakage.
- Local context is injected before generating `delta`, `B`, and `C`, which often improves quality and stability.
In plain terms: the conv gives the block a local pattern detector before long-range recurrent memory update.
7. Causality and streaming inference
Mamba in this repo supports both:
- full-sequence forward for training
- one-token `step()` for streaming generation
For streaming, each layer carries exactly two state tensors:
- SSM state: ((B, d_{inner}, N))
- Conv cache: ((B, d_{inner}, K-1)) for kernel size (K)
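The conv cache works like a sliding window over the last K-1 inputs. A minimal single-channel sketch (function and variable names here are illustrative, not the repo's `step()` API):

```python
from collections import deque

def conv_cache_step(cache, w, x_t):
    """One streaming step of a depthwise causal conv for one channel.

    cache holds the previous K-1 inputs (oldest first); w is the
    length-K kernel, oldest tap first.
    """
    window = list(cache) + [x_t]
    y_t = sum(wi * xi for wi, xi in zip(w, window))
    cache.append(x_t)  # deque(maxlen=K-1) drops the oldest input
    return y_t

K = 4
w = [0.25, 0.25, 0.25, 0.25]                  # simple averaging kernel
cache = deque([0.0] * (K - 1), maxlen=K - 1)  # zero history = causal padding
ys = [conv_cache_step(cache, w, x) for x in [1.0, 1.0, 1.0, 1.0]]
# ys == [0.25, 0.5, 0.75, 1.0]: the window fills over the first K steps
```

The key property is that the cache size is fixed at K-1 regardless of how many tokens have been generated, mirroring the fixed-size SSM state.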
The single-token update in MambaBlock.step follows the same math as full forward. That is tested explicitly.
This is one of the most practical advantages of Mamba-like architectures: inference memory per layer is fixed-size in sequence length, unlike full attention caches.
You can think of the two forward modes as:
- Batch mode: compute all tokens in one call for efficient training.
- Step mode: update one token at a time while carrying compact state.
Correctness requirement: these modes must agree numerically (up to tolerance) when fed the same sequence.
Code links:
- Block step: `mamba4080/modules/mamba_block.py`
- Model step: `mamba4080/models/mamba_lm.py`
- Test: `mamba4080/tests/test_mamba_block.py`
y_full = block(x)
# ... run token-by-token with block.step(...)
torch.testing.assert_close(y_step, y_full, atol=1e-5, rtol=1e-5)
In practice, generation loops follow this pattern:
- initialize per-layer state once,
- feed prompt tokens through `step` to prime state,
- sample one token, feed it back, repeat.
That stateful decode path is exactly what keeps memory usage bounded as output length grows.
8. Selective scan implementations in this project
The repository intentionally keeps three paths, each with a different purpose:
- `naive`: scalar-style loops; easiest to reason about.
- `vectorized`: practical PyTorch baseline with explicit time loop.
- `triton`: optimized CUDA forward path with custom autograd backward and safe fallback.
Keeping all three is not redundant. It gives you:
- A trusted correctness anchor (`naive`).
- A portable baseline (`vectorized`).
- A performance path (`triton`) you can test against both.
All three are dispatched through one entrypoint.
def selective_scan(..., implementation: Literal["naive", "vectorized", "triton"] = "vectorized"):
if implementation == "naive":
return selective_scan_naive(...)
if implementation == "vectorized":
return selective_scan_vectorized(...)
if implementation == "triton":
# try triton, fallback safely when unavailable/unsupported
Code: mamba4080/ops/selective_scan.py
Practical usage guidance:
- Use `naive` when validating formulas, debugging gradients, or writing new tests.
- Use `vectorized` as the default correctness/performance baseline on any device.
- Use `triton` for CUDA performance runs after parity checks pass.
A useful discipline is to always check a new change in this order:
- `naive` vs `vectorized`,
- `vectorized` vs `triton`,
- end-to-end training smoke.
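The parity idea can be demonstrated in miniature by comparing a stepwise scan against an algebraically unrolled closed form (plain Python, single channel; not the repo's test code):

```python
import math

def scan_loop(u, delta, A, B, C):
    """Stepwise recurrence, as in the naive/vectorized paths."""
    x, ys = 0.0, []
    for t in range(len(u)):
        x = math.exp(delta[t] * A) * x + delta[t] * B[t] * u[t]
        ys.append(C[t] * x)
    return ys

def scan_unrolled(u, delta, A, B, C):
    """Closed form: x_t = sum_s exp(A * sum(delta[s+1..t])) * delta_s B_s u_s."""
    ys = []
    for t in range(len(u)):
        x = sum(math.exp(A * sum(delta[s + 1 : t + 1])) * delta[s] * B[s] * u[s]
                for s in range(t + 1))
        ys.append(C[t] * x)
    return ys

u, delta = [1.0, -2.0, 0.5], [0.1, 0.2, 0.3]
B, C, A = [1.0, 0.5, 2.0], [1.0, 1.0, -1.0], -1.0
ref, alt = scan_loop(u, delta, A, B, C), scan_unrolled(u, delta, A, B, C)
assert all(abs(a - b) < 1e-9 for a, b in zip(ref, alt))
```

Two implementations that share no code but agree numerically give much stronger evidence than two copies of the same loop.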
9. Triton selective scan: broad view to line-by-line detail
9.1 Broad view
The Triton path accelerates the selective scan forward pass on CUDA and uses an explicit custom backward pass in PyTorch for gradient correctness.
Architecturally:
- Forward recurrence is fused in a Triton kernel.
- Backward recurrence is implemented as explicit reverse-time math in a `torch.autograd.Function`.
- Fallback routes to the vectorized implementation when environment or dtype constraints are not satisfied.
Code: mamba4080/ops/selective_scan_triton.py
9.2 Forward kernel anatomy
Key ideas in the forward kernel:
- one Triton program instance handles a `(batch_index, channel_index)` pair
- recurrence over time remains explicit (`for t in range(seq_len)`)
- state dimension processed in blocks (`block_n`)
- output accumulation across state blocks
Memory layout perspective:
- `u` and `delta` are indexed by token and channel.
- `A` is indexed by channel and state.
- `B` and `C` are indexed by token, channel, and state.
- For one `(b, d)` pair, the kernel walks through tokens and updates a small state vector for that channel.
Representative snippet:
pid = tl.program_id(0)
b = pid // d_inner
d = pid % d_inner
for n_start in tl.static_range(0, d_state, block_n):
state = ...
a_vals = ...
for t in range(seq_len):
u_t = ...
delta_t = ...
b_vals = ...
c_vals = ...
a_bar = tl.exp(delta_t * a_vals)
state = a_bar * state + (delta_t * b_vals) * u_t
contrib = tl.sum(c_vals * state, axis=0)
This is mathematically the same recurrence as the reference path; the speedup comes from fused kernel execution and lower Python/dispatcher overhead.
Line-by-line walkthrough landmarks
- Program ID mapping converts the flat launch index into `(batch, channel)`.
- Per-state block loop loads a chunk of `A` and initializes local state registers.
- Inner time loop loads `u_t`, `delta_t`, `B_t`, and `C_t`.
- Recurrence update computes `a_bar`, updates state, and accumulates the output contribution.
- Skip path is added once per token.
- Optional final state writeback stores the last recurrent state for streaming or external reuse.
Code section: mamba4080/ops/selective_scan_triton.py
9.3 Backward recurrence anatomy
Backward first reconstructs the forward state trajectory in float32, then performs reverse-time recursion.
Conceptually, for each time step in reverse order, it computes:
- Gradient wrt current state from output and next-state dependency.
- Gradient wrt readout coefficients `C_t`.
- Gradient wrt input `u_t` from both skip path and recurrent path.
- Gradient wrt `B_t`, `delta_t`, and `A`.
- Carry gradient to previous state.
This is exactly what you would derive by applying chain rule to the recurrence.
Core pattern:
for t in range(seq_len - 1, -1, -1):
grad_state = grad_state_next + gy_t.unsqueeze(-1) * C_t
grad_C[:, t] = gy_t.unsqueeze(-1) * x_t
grad_u[:, t] += sum(grad_state * (delta_t * B_t), dim=-1)
grad_B[:, t] = grad_state * (delta_t * u_t)
grad_delta[:, t] = sum(grad_state * (...), dim=-1)
grad_A += sum(grad_state * (...), dim=0)
grad_state_next = grad_state * A_bar_t
Local derivatives at one token are:
\[\frac{\partial y_t}{\partial C_t} = x_t,\qquad \frac{\partial y_t}{\partial x_t} = C_t\]
\[\frac{\partial x_t}{\partial u_t} = \Delta_t \odot B_t,\qquad \frac{\partial x_t}{\partial B_t} = \Delta_t \odot u_t\]
\[\frac{\partial x_t}{\partial \Delta_t} = \left(A \odot e^{\Delta_t \odot A}\right)\odot x_{t-1} + B_t \odot u_t\]
\[\frac{\partial x_t}{\partial A} = \Delta_t \odot e^{\Delta_t \odot A}\odot x_{t-1}\]

The backward loop in code is just these local derivatives plus reverse-time accumulation.
Important points:
- gradient math is explicit, not delegated to opaque black-box kernels
- compute promoted to float32 in backward for stability
- returned grads cast back to original dtypes
- the backward logic matches the same recurrence structure used in forward, which helps parity debugging
Code path to inspect: _SelectiveScanTritonFunction.backward in
mamba4080/ops/selective_scan_triton.py
9.4 Safety and fallback behavior
Before launching Triton, code checks:
- Triton installed
- compiler availability (`gcc` or `clang`)
- CUDA device
- supported dtype (`fp16`/`fp32`)
- device consistency
Fallback records a reason and routes to vectorized implementation.
if u.dtype not in (torch.float16, torch.float32):
_warn_fallback(f"unsupported dtype {u.dtype}")
return selective_scan_vectorized(...)
Code: mamba4080/ops/selective_scan_triton.py
9.5 Mapping math to memory offsets
For readers implementing kernels themselves, this mapping is critical.
Given contiguous layouts used in this repository:
- `u` and `delta` have shape `(B, L, d_inner)` and are laid out with `d_inner` as the fastest axis.
- `A` has shape `(d_inner, d_state)`.
- `B` and `C` have shape `(B, L, d_inner, d_state)`.
Inside the kernel, with fixed (b, d):
- `u_base = b * L * d_inner + d` points to the token-0 element for that channel.
- `token_off = u_base + t * d_inner` moves along time at fixed channel.
- `ssm_off = ((b * L + t) * d_inner + d) * d_state + n_offsets` indexes `B_t` and `C_t` for state block `n_offsets`.
That is why one program ID per (b, d) pair is natural: contiguous time traversal for one channel minimizes index complexity and keeps the recurrence state local.
For backward, the same indexing logic applies in reverse-time order, with the addition that reconstructed forward states are read from a `(B, L, d_inner, d_state)` buffer.
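These offset formulas are just row-major flattening, which can be verified with a few lines of plain Python (illustrative, mirroring but not importing the kernel code):

```python
def flat_u(b, t, d, L, d_inner):
    """Row-major offset into a contiguous (B, L, d_inner) buffer."""
    return (b * L + t) * d_inner + d

B, L, d_inner = 2, 3, 4
count = 0
for b in range(B):
    for t in range(L):
        for d in range(d_inner):
            # Offsets must match plain enumeration order of the buffer.
            assert flat_u(b, t, d, L, d_inner) == count
            count += 1

# For fixed (b, d), stepping one token forward advances the offset by
# exactly d_inner, which is the kernel's token_off stride.
u_base = 1 * L * d_inner + 2  # b = 1, d = 2, token 0
assert flat_u(1, 1, 2, L, d_inner) == u_base + d_inner
```

Writing such a throwaway check before touching kernel pointer arithmetic is usually faster than debugging a miscomputed stride inside Triton.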
Optional deep dive: Why not make the entire backward Triton-native immediately?
A fully Triton-native backward is possible, but it is a substantial engineering step:
- You have to manage reverse-time dependencies and accumulation carefully in kernel space.
- Debugging gradient mismatches becomes much harder.
- You need stronger numerical parity tooling to keep confidence high.
The current split design is deliberate. It gives most of the forward speed benefit while keeping backward math explicit and testable. That is a strong intermediate point for a correctness-first open-source project.
10. Building and running the kernel on RTX 4080-class GPUs
This repository is tuned for Ada-class GPUs (SM89), including RTX 4080 Laptop/desktop class devices, while keeping a portable fallback path.
In practice, “tuned” means:
- default presets assume one strong consumer GPU,
- kernels and benchmarks target realistic single-GPU memory envelopes,
- throughput measurements are collected with this hardware class in mind.
10.1 Environment checks
Run:
uv run python doctor.py
Expected high-level checks:
- CUDA available
- compute capability 8.9 (`sm89`)
- compiler available (`gcc` or `clang`)
- Triton import works
Code: doctor.py
If doctor.py reports missing compiler or Triton import issues, the code still runs, but selective scan acceleration will fall back to the vectorized path.
10.2 Runtime policies used in this repo
In presets and runtime setup:
- TensorFloat-32 can be enabled for CUDA matmul/cudnn
- training preset (`rtx4080`) defaults to:
  - `ssm_impl = triton`
  - AMP on
  - `amp_dtype = bf16`
  - `allow_tf32 = true`
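For reference, these policies map onto standard PyTorch switches roughly like this. This is a hedged sketch of the usual PyTorch APIs, not the repo's preset loader:

```python
import torch

# Throughput-first: TF32 matmuls plus bf16 autocast.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

def forward_amp(model, batch):
    # Mixed-precision forward; the selective scan may still run in fp32
    # internally depending on the dispatch path.
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        return model(batch)

def forward_strict(model, batch):
    # Correctness-first: disable TF32 and skip autocast entirely.
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    return model(batch)
```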
Code:
These settings are intentionally exposed as flags so you can toggle correctness-first runs (fp32, no AMP) and throughput-first runs (AMP, Triton dispatch) without changing code.
10.3 Why this is “adapted” to RTX 4080
- Benchmark profiles are designed around single-GPU practical sequence lengths and memory limits.
- Defaults target strong throughput without moving to distributed complexity.
- The optimized path is chosen where the hardware gives the biggest payoff (Triton fp32 scan).
Bridge: Is this code only for RTX 4080?
No. The model and tests run on CPU and other CUDA GPUs via the reference/vectorized implementations.
What is hardware-specific is the optimization focus:
- the benchmark sweeps,
- the default presets,
- the speedup claims.
If you run on a different GPU class, you should still expect correctness. You should just re-run benchmarks before drawing performance conclusions.
11. Known limitation: bf16 Triton scan fallback
Current behavior in this codebase:
- Triton selective scan path accepts `fp16` and `fp32` inputs.
- If scan inputs are bf16, the dispatcher falls back to the vectorized PyTorch scan.
Reason in code:
if u.dtype not in (torch.float16, torch.float32):
_warn_fallback(f"unsupported dtype {u.dtype}")
return selective_scan_vectorized(...)
Code: mamba4080/ops/selective_scan_triton.py
Practical interpretation:
- `triton + fp32`: real kernel speedup path
- `triton + bf16`: correct behavior, but the scan kernel itself is not accelerated
This is why benchmark reports include both `triton_active` and `fallback_reason` fields.
When reading training throughput numbers, this distinction matters a lot:
- model-level throughput can still improve due to AMP and other kernels,
- selective-scan-specific acceleration is only active in supported dtypes.
Quick troubleshooting checklist:
- If Triton appears unexpectedly slow, inspect `triton_active` in benchmark rows.
- If fallback happened, inspect `fallback_reason` before tuning model hyperparameters.
- For kernel-level speedup analysis in the current codebase, use fp32 scan runs.
12. Building the language model around Mamba blocks
MambaLM stacks MambaBlock layers with token embeddings and a final LM head.
The model intentionally stays simple:
- token embedding
- repeated Mamba blocks
- final normalization
- linear LM head
That simplicity is useful for debugging because most behavior can be traced to the block and scan implementations directly.
The base model also avoids extra architectural complications on purpose. In particular, it keeps the sequence stack simple so recurrence behavior is easier to attribute during debugging and benchmarking.
Shapes in forward:
- `input_ids`: `(B, L)`
- embeddings: `(B, L, d_model)`
- block stack preserves `(B, L, d_model)`
- logits: `(B, L, vocab_size)`
Code:
Snippet:
x = self.token_embedding(input_ids)
for layer in self.layers:
x = layer(x)
x = self.norm_f(x)
logits = self.lm_head(x)
Streaming mode uses step() and per-layer (ssm_state, conv_state) tuples. The model API keeps this explicit so generation behavior remains easy to inspect and test.
Most impactful model knobs:
- `d_model`: representational width and parameter budget.
- `n_layers`: depth and total recurrent capacity.
- `d_state`: long-memory capacity per channel.
- `d_conv`: size of local pre-SSM context.
- `expand`: ratio between model width and inner scan width.
13. Data, training, checkpointing, and generation
13.1 Dataset and vocabulary
TinyShakespeareDataset loads text, builds char-level vocabulary, and samples contiguous chunks for next-token prediction.
Why character-level first:
- zero tokenizer dependencies,
- easy to debug sequence alignment,
- quick iteration for architecture and kernel correctness.
x = tokens[idx : idx + context_len]
y = tokens[idx + 1 : idx + 1 + context_len]
Code: mamba4080/training/char_data.py
Canonical dataset path is data/tiny_shakespeare.txt, with legacy alias support for tiny_shakespear.txt.
The dataset pipeline is deterministic under fixed seeds, which is important for reproducible debugging and benchmark comparisons.
Because this is character-level data, quick qualitative checks are easy: decode a sampled batch and verify token continuity by eye.
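The character-level setup is small enough to sketch end to end. This is illustrative plain Python; the repo's TinyShakespeareDataset will differ in details:

```python
text = "hello world"

# Build the char-level vocabulary and encode the text.
chars = sorted(set(text))
stoi = {c: i for i, c in enumerate(chars)}
itos = {i: c for c, i in stoi.items()}
tokens = [stoi[c] for c in text]

# Sample one contiguous chunk for next-token prediction.
context_len, idx = 4, 2
x = tokens[idx : idx + context_len]
y = tokens[idx + 1 : idx + 1 + context_len]

# The eyeball check from above: decode x and confirm the shift-by-one target.
assert "".join(itos[i] for i in x) == text[idx : idx + context_len]
assert y[:-1] == x[1:]
```

The `y[:-1] == x[1:]` property is the single most common thing to break when refactoring data loading, and it is trivial to assert.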
13.2 Training loop
Training script uses configurable model size, sequence length, AMP policy, and implementation path (naive/vectorized/triton).
Code: train.py
A practical workflow:
- start with `--ssm-impl vectorized` and fp32 for strict correctness,
- switch to `--ssm-impl triton`,
During early experiments, log at least:
- training loss,
- validation loss,
- tokens/sec,
- peak memory.
That set is usually enough to detect both optimization regressions and data/label alignment bugs.
Minimal run:
uv run python train.py \
--preset rtx4080 \
--device cuda \
--steps 1000 \
--ssm-impl triton
13.3 Checkpoint and generation
Checkpoint includes model, optimizer, config, vocab, and metadata.
Code: mamba4080/training/checkpointing.py
Generation uses streaming model.step(...) for token-by-token decode:
for tok in prompt_ids:
logits_t, states = model.step(token_t, states=states)
for _ in range(max_new_tokens):
next_id = sample_next_token(logits_t, temperature, top_k)
logits_t, states = model.step(next_id.squeeze(-1), states=states)
Code: generate.py
Generation quality depends heavily on sampling controls. For this repository, temperature and top_k are the two most important knobs for balancing coherence and diversity.
A robust generation sanity check:
- fix seed, prompt, temperature, and top-k,
- run generation twice from the same checkpoint,
- confirm deterministic token stream before exploring creative settings.
14. Correctness methodology in this project
The testing strategy is layered so speedups never silently change model semantics.
14.1 Shape and interface checks
- selective scan shape validation in code
- block/model interface tests
These catch the highest-frequency class of bugs: silent shape drift when modifying projections, state dimensions, or dispatcher signatures.
14.2 Causality checks
Future-token perturbation should not change earlier outputs.
This is a non-negotiable property for language modeling. Every implementation path (naive, vectorized, triton) is expected to preserve this.
Code examples:
14.3 Numerical parity checks
- naive vs vectorized parity
- triton vs vectorized parity
Parity tests are run with fixed seeds and tolerances so regressions are easy to identify.
Parity is checked on both sequence outputs and recurrent end-state when relevant, which helps catch subtle state drift bugs.
14.4 Gradient checks
- autograd finite gradients
- finite-difference checks for selective scan
Finite-difference checks matter because two implementations can agree with each other and still both be wrong.
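Here is that idea in miniature for the gradient of a loss with respect to A in the scalar recurrence (illustrative plain Python; the repo's checks operate on full tensors):

```python
import math

def scan_sum(A, u, delta, B, C):
    """Loss = sum_t C_t x_t for the scalar recurrence."""
    x, total = 0.0, 0.0
    for t in range(len(u)):
        x = math.exp(delta[t] * A) * x + delta[t] * B[t] * u[t]
        total += C[t] * x
    return total

def grad_A(A, u, delta, B, C):
    """Analytic dLoss/dA via reverse-time accumulation, as in section 9.3."""
    xs, x = [], 0.0
    for t in range(len(u)):
        x = math.exp(delta[t] * A) * x + delta[t] * B[t] * u[t]
        xs.append(x)
    gA, gx_next = 0.0, 0.0
    for t in range(len(u) - 1, -1, -1):
        gx = gx_next + C[t]                   # direct + downstream path
        x_prev = xs[t - 1] if t > 0 else 0.0
        a_bar = math.exp(delta[t] * A)
        gA += gx * delta[t] * a_bar * x_prev  # dx_t/dA = delta_t e^{delta_t A} x_{t-1}
        gx_next = gx * a_bar                  # carry gradient to x_{t-1}
    return gA

A, eps = -0.5, 1e-5
u, delta = [1.0, -2.0, 0.5], [0.1, 0.2, 0.3]
B, C = [1.0, 0.5, 2.0], [1.0, 1.0, -1.0]
fd = (scan_sum(A + eps, u, delta, B, C)
      - scan_sum(A - eps, u, delta, B, C)) / (2 * eps)
assert abs(fd - grad_A(A, u, delta, B, C)) < 1e-6
```

If the analytic and central-difference gradients disagree, the discrepancy is in the derivative math itself, independent of any kernel or framework detail.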
Code: mamba4080/tests/test_selective_scan_math.py
14.5 Pipeline checks
- short overfit loss decreases
- checkpoint load + deterministic generation
These are integration-level checks. They ensure that data pipeline, model, optimizer, checkpointing, and decoding all work together, not just in isolated unit tests.
Code: mamba4080/tests/test_training_pipeline.py
If you modify scan math or kernel logic, the safest test order is:
- selective scan unit tests,
- block/model causality tests,
- training pipeline smoke tests,
- benchmark suite regression comparison.
15. Experimental results from this project (smoke + practical)
This section summarizes recorded experiments run on an RTX 4080 Laptop GPU.
15.1 Smoke profile summary
Source artifact:
Key points:
- total rows: 10
- successful rows: 10
- triton fallbacks: 0
- tier-A fp32 throughput speedup (`triton`/`vectorized`):
  - \(L=128\): \(1.279\times\)
  - \(L=256\): \(1.256\times\)
- end-to-end train throughput: 3397.2 tok/s
- streaming generation throughput: 480.7 tok/s
Interpretation: smoke profile is intentionally compact and fast. It is best used as a quick regression screen after code changes.
15.2 Practical profile summary
Source artifact:
Key points:
- total rows: 38
- successful rows: 38
- triton fallback rows: 9 (bf16 scan fallback)
- tier-A fp32 throughput speedup (`triton`/`vectorized`):
  - \(L=128\): \(1.364\times\)
  - \(L=256\): \(1.143\times\)
  - \(L=512\): \(1.337\times\)
- best tier-B fp32 throughput:
  - vectorized: 856.7 tok/s at `B=4, L=256`
  - triton: 1366.9 tok/s at `B=4, L=512`
- tier-C train throughput: 1761.4 tok/s
- tier-C generation throughput: 207.4 tok/s
Interpretation: practical profile probes more combinations and exposes fallback behavior, so it is the better source for real tradeoff decisions.
15.3 Interpreting smoke vs practical
Smoke and practical are not directly apples-to-apples. Practical runs larger sweeps and heavier settings (including bf16 entries that may fallback), so absolute throughput can differ even when the underlying kernel speedup is real.
Use rows with `triton_active=true` to measure actual Triton kernel impact. Use bf16 rows to assess end-to-end behavior under the current fallback policy.
Figure: throughput versus sequence length for smoke and practical suites, with fp32 and bf16 variants shown for vectorized and Triton dispatch.
What to look for:
- fp32 Triton should sit above fp32 vectorized in tier-A scan rows.
- bf16 Triton may track closer to vectorized when scan fallback is active.
- slope changes across sequence length indicate where recurrence cost dominates.
Figure: forward and backward latency breakdown on practical fp32 scan configurations. Backward pass dominates as sequence length grows.
This plot is a reminder that optimizing forward alone is not enough for training throughput. If backward dominates, end-to-end wins depend on backward strategy as much as forward kernel speed.
Figure: peak memory across implementations and dtypes in the practical suite. bf16 Triton fallback points are explicitly marked.
Memory interpretation is often counterintuitive: a faster kernel is not automatically a lower-memory kernel. Always measure both.
Figure: speedup ratio \(S(L)=\frac{\text{throughput}_{\mathrm{triton}}}{\text{throughput}_{\mathrm{vectorized}}}\). Values above 1 indicate Triton improvement over the vectorized baseline.
This ratio plot is usually the cleanest way to compare revisions because it normalizes away many absolute-runtime differences between runs.
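Computing that ratio from a results file is a few lines. The row schema below (`impl`, `L`, `tok_per_s`) is an assumption for illustration, not the repository's actual summary format; adapt the key names to whatever the benchmark JSON emits.

```python
def speedup_by_length(rows):
    """Compute S(L) = triton throughput / vectorized throughput per length.

    rows: iterable of dicts with keys 'impl', 'L', 'tok_per_s'
          (an assumed schema, not the suite's real one).
    """
    by_key = {(r["impl"], r["L"]): r["tok_per_s"] for r in rows}
    lengths = sorted({L for impl, L in by_key if impl == "triton"})
    return {
        L: by_key[("triton", L)] / by_key[("vectorized", L)]
        for L in lengths
        if ("vectorized", L) in by_key
    }

rows = [
    {"impl": "vectorized", "L": 128, "tok_per_s": 500.0},
    {"impl": "triton", "L": 128, "tok_per_s": 682.0},
]
ratios = speedup_by_length(rows)  # {128: 1.364}
```

Because both throughputs come from the same run, machine-level noise largely cancels out of the ratio, which is why it compares well across revisions.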
16. End-to-end command checklist
This is a practical runbook for common workflows: environment validation, correctness testing, training, generation, and benchmarking.
16.1 Environment and correctness
uv sync --extra dev --extra triton
uv run python doctor.py
uv run --extra dev pytest -q
If you are actively editing kernels or recurrence math, run this block first after every substantial change.
16.2 Train and generate
uv run python train.py \
--preset rtx4080 \
--device cuda \
--steps 1000 \
--ssm-impl triton
uv run python generate.py \
--checkpoint checkpoints/mamba_tiny.pt \
--prompt "ROMEO:" \
--max-new-tokens 400 \
--temperature 0.9 \
--top-k 40 \
--device cuda
Recommended workflow:
- first run a short training smoke (--steps 100),
- validate generation loads the produced checkpoint,
- then scale steps and sequence length.
16.3 Benchmark and report
uv run python benchmark_suite.py \
--profile practical \
--device cuda \
--out-dir benchmarks/runs/rtx4080_practical
uv run python plot_benchmarks.py \
--input-json benchmarks/results/rtx4080_suite_practical_summary.json \
--out-dir benchmarks/figures \
--title-prefix "Mamba4080 Practical"
uv run python make_benchmark_report.py \
--input-json benchmarks/results/rtx4080_suite_practical_summary.json \
--output-md benchmarks/report_practical.md \
--fig-dir benchmarks/figures \
--title "Mamba4080 Practical Benchmark Report"
uv run python make_tutorial_figures.py \
--smoke-json benchmarks/results/rtx4080_suite_smoke_summary.json \
--practical-json benchmarks/results/rtx4080_suite_practical_summary.json \
--out-dir assets/img/posts/mamba4080_tutorial
Use the same benchmark command templates when comparing code revisions. That keeps apples-to-apples comparisons honest.
17. Appendix A: repository code map (concept-first)
If you want to read the code in a pedagogical order:
- Selective scan math and dispatch: selective_scan.py
- Triton optimized path and backward: selective_scan_triton.py
- Causal conv and block composition: mamba_block.py
- Stacked model:
- Data and training: train.py
- Generation and streaming state: generate.py
- Benchmarking: benchmark_suite.py
Reading strategy:
- Start with selective_scan.py until the recurrence is intuitive.
- Move to mamba_block.py to see how parameters are generated and used.
- Read selective_scan_triton.py only after the reference path is clear.
- End with training and benchmarking scripts so the system-level picture clicks.
18. Appendix B: optional deep-dive bridges
Bridge: Discretization nuance and \(\Delta_t\)
You can interpret \(\Delta_t\) (delta_t in code) as a learned, token-dependent local clock.
When \(\Delta_t\) is larger, the factor \(\exp(\Delta_t A)\) decays the old state more aggressively (for negative values of \(A\)). When \(\Delta_t\) is smaller, old state is preserved longer. This means the model can dynamically choose memory timescales at token resolution.
That interpretation is useful when debugging:
- If delta collapses to tiny values, the model may become overly persistent.
- If delta becomes too large everywhere, memory may be overwritten too quickly.
- Monitoring delta statistics during training can reveal these pathologies early.
Bridge: Why both conv and recurrent memory?
Convolution and recurrence solve different parts of sequence modeling:
- Causal depthwise conv captures short-range local patterns very efficiently.
- Recurrent SSM carries long-range information with constant per-token state.
Using both lets the model avoid forcing one mechanism to do everything. Local structure is handled locally; long memory is handled recurrently.
In implementation terms, this split is clean:
- conv updates a short cache (K-1 tokens per channel),
- SSM updates a long-memory state (d_state values per channel).
That separation is one reason the streaming API remains manageable.
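The per-layer streaming state can be captured in a small container like the one below. This is an illustrative sketch, not the repository's actual cache class; it assumes a depthwise causal conv of width K over d_inner channels and a d_state-wide SSM.

```python
from dataclasses import dataclass
import torch

@dataclass
class LayerCache:
    """Illustrative per-layer streaming state for one Mamba-style block."""
    conv: torch.Tensor  # (B, d_inner, K-1): last K-1 inputs for the causal conv
    ssm: torch.Tensor   # (B, d_inner, d_state): recurrent SSM state

    @staticmethod
    def zeros(batch: int, d_inner: int, k: int, d_state: int) -> "LayerCache":
        return LayerCache(
            conv=torch.zeros(batch, d_inner, k - 1),
            ssm=torch.zeros(batch, d_inner, d_state),
        )

# One cache per layer; total streaming state is small and length-independent.
cache = LayerCache.zeros(batch=1, d_inner=64, k=4, d_state=16)
```

Note that neither tensor grows with sequence length, which is exactly what makes constant-memory streaming generation possible.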
Bridge: Why include finite-difference tests?
Parity tests tell you two implementations match each other. Finite-difference checks tell you an implementation's gradients match the numerical derivative of its own forward pass.
You need both:
- parity alone can miss a bug that both implementations share,
- finite-difference is slower but independent of autograd graph structure,
- together they sharply reduce the chance of “fast but wrong” kernel updates.
For this project, finite-difference checks are intentionally small and targeted so they can run routinely without making the test suite painful.
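The core of such a check fits in a few lines. The sketch below compares an autograd gradient against a central-difference estimate on a toy scalar function rather than the repository's kernels; the same pattern applies to a scan output reduced to a scalar.

```python
import torch

def central_diff_grad(f, x, eps=1e-3):
    """Central-difference gradient of scalar-valued f at x (use float64)."""
    g = torch.zeros_like(x)
    flat, gf = x.view(-1), g.view(-1)
    for i in range(flat.numel()):
        old = flat[i].item()
        flat[i] = old + eps
        up = f(x).item()
        flat[i] = old - eps
        down = f(x).item()
        flat[i] = old  # restore before moving on
        gf[i] = (up - down) / (2 * eps)
    return g

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
f = lambda t: (t * t).sum()  # analytic gradient is 2t

(analytic,) = torch.autograd.grad(f(x), x)
numeric = central_diff_grad(f, x.detach().clone())
assert torch.allclose(analytic, numeric, atol=1e-6)
```

The loop is O(numel) forward passes, which is why these checks stay small and targeted: tiny tensors, float64, a handful of configurations.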
19. Conclusion
Mamba is easiest to internalize if you keep three layers in your head at once:
- Idea layer: dynamic, content-dependent memory in linear-time recurrence.
- Math layer: selective state update with token-conditioned \((\Delta_t, B_t, C_t)\).
- Systems layer: reference path for correctness, vectorized path for baseline speed, Triton path for GPU acceleration with safe fallback.
This tutorial and repository are structured around that stack deliberately:
- the equations remain explicit,
- the code paths remain inspectable,
- the performance path remains test-backed.
If you can read the scan equations, trace MambaBlock end to end, and reason through the Triton forward/backward flow, you now have the core mental model needed to build, verify, and tune a serious Mamba implementation.
References
[1] Albert Gu and Tri Dao. Mamba: Linear-Time Sequence Modeling with Selective State Spaces. arXiv:2312.00752, 2023. https://arxiv.org/abs/2312.00752
[2] Albert Gu, Karan Goel, and Christopher Ré. Efficiently Modeling Long Sequences with Structured State Spaces. ICLR 2022. arXiv:2111.00396. https://arxiv.org/abs/2111.00396
[3] Albert Gu, Tri Dao, Stefano Ermon, Atri Rudra, and Christopher Ré. HiPPO: Recurrent Memory with Optimal Polynomial Projections. NeurIPS 2020. arXiv:2008.07669. https://arxiv.org/abs/2008.07669
[4] Philippe Tillet, H. T. Kung, and David Cox. Triton: An Intermediate Language and Compiler for Tiled Neural Network Computations. MAPL 2019. https://github.com/triton-lang/triton
[5] Official Triton language and compiler documentation. https://triton-lang.org