Mamba from First Principles

Selective State Spaces, Parallel Scan, and a PyTorch Implementation

** Post is under construction **

1. Introduction

Sequence modeling has long been shaped by a basic tension. We want models that respond sharply to token content, yet we also want training and inference to remain efficient as context length grows. Modern transformers solve the first problem extremely well. Their central operation is explicitly content-dependent. Each token can decide which previous tokens matter, and by how much. That flexibility is one of the main reasons attention-based models have become the dominant architecture for language.

The cost of that flexibility is equally familiar. Full attention scales quadratically with sequence length in both compute and memory. This is not merely a theoretical inconvenience. In practical systems, long context quickly becomes a hardware problem. Memory traffic rises, latency rises, and the cost of caching large histories during autoregressive generation becomes a major systems constraint.

Recurrent models sit at the opposite end of the design space. They process a sequence one step at a time and maintain a finite hidden state. This gives them a natural linear-time causal structure and makes streaming inference conceptually simple. The difficulty is that classical recurrent models compress the entire past through a fixed update rule. That fixed rule is often too rigid for language-like data, where some tokens should be preserved over long spans while many others should be forgotten almost immediately.

Mamba is interesting because it addresses exactly this limitation. It keeps the recurrent state-space backbone, but it does not keep a fixed recurrence. Instead, the state update changes with the input token. The model still carries a finite recurrent state forward in time, but the policy that governs how that state is updated becomes content-dependent.

This is the central idea of the tutorial:

Mamba is a recurrent model whose memory update rule is generated from the current token.

That sentence captures what is distinctive about the architecture. Mamba does not recover content dependence by revisiting the entire past through attention. It recovers content dependence by making the recurrence itself adaptive.

This tutorial develops that idea from first principles. The goal is not to survey the literature or provide a benchmark report. The goal is narrower and more useful: to build a precise mental model of the selective recurrence at the core of Mamba, and then connect that mental model to a clean PyTorch implementation.

By the end, you should be able to do six things with confidence:

  1. explain why fixed recurrence is too blunt for language-like data
  2. derive the selective state update used by Mamba
  3. interpret the roles of $A$, $\Delta_t$, $B_t$, $C_t$, and the skip path
  4. understand how a Mamba block generates a token-dependent memory policy
  5. implement the reference recurrence in PyTorch with correct tensor shapes
  6. see why scan-style parallelization becomes necessary once selection breaks time invariance

Kernel-level optimization is best treated separately. Here, the focus is conceptual clarity. If the recurrence is clear, the block structure becomes clear. If the block structure is clear, the implementation becomes straightforward. If the implementation is clear, later systems questions have a solid foundation.

[Placeholder Figure 1]

The right way to read Mamba is through three layers at once.

The first is the dynamical systems layer. Mamba inherits the language of state space models, where hidden state evolves according to structured dynamics and produces an output through a learned readout.

The second is the sequence modeling layer. Mamba uses that state not as a generic latent vector, but as a learned memory that is updated token by token. The model decides, at each token, what should decay, what should be written, and how the resulting state should be read out.

The third is the algorithmic layer. Once the update rule becomes token-dependent, the model is no longer time-invariant. That removes the simple convolutional trick that made earlier state space models especially efficient during training. The recurrence remains causal and finite-state, but it now requires a scan-oriented computational strategy.

This post concentrates on the first two layers and introduces the third only at the level needed to understand the architecture. The guiding principle throughout is simple: every equation should answer a modeling question, and every code block should reveal an idea.

Scope

This post focuses on the mechanism of Mamba itself:

  • why selection is needed
  • how the selective recurrence is constructed
  • how the block produces the recurrence parameters
  • how to implement the reference version in PyTorch
  • why scan enters the picture

It does not attempt to cover:

  • kernel engineering details
  • Triton implementation details
  • benchmark methodology
  • repository walkthroughs
  • later variants such as Mamba-2

Those topics are better handled separately.


2. The problem Mamba is actually solving

A recurrent model is always a compression mechanism.

This is true whether the model is a classical RNN, an LSTM, a GRU, or a structured state space model. At each time step, the model receives a new token and updates a finite hidden state. That hidden state must summarize everything from the past that may matter in the future. Once viewed from this angle, the central question is not whether recurrence compresses information. It must. The central question is whether the compression policy is expressive enough to preserve the right information and discard the rest.

Language makes this difficult.

Some tokens carry information that should persist for a long time. A speaker name introduced early in a dialogue can matter many sentences later. A variable name in code may need to remain active across a long intervening span. A negation can alter the interpretation of later words. At the same time, many tokens are locally useful but globally disposable. Punctuation, routine fillers, and short-lived syntactic cues often matter only briefly.

A fixed recurrence must process all of these cases through the same update rule.

That is the key limitation.

To see the issue more sharply, consider a toy sequence task. Suppose the input contains many irrelevant filler tokens, punctuated occasionally by a marked token that must be remembered until a later query token appears. A model that forgets too aggressively will lose the marked token before it is needed. A model that preserves information too aggressively will keep accumulating irrelevant filler and pollute its state. In both cases, the problem is not that the model lacks recurrence. The problem is that it applies one global memory policy to every token.

That is too blunt an instrument for information-dense sequences.

A fixed linear recurrence has no direct way to say:

  • this token should be written strongly
  • this token should be mostly ignored
  • this token should trigger rapid decay of stale state
  • this token should preserve what is already present

Yet these are exactly the distinctions that language models need to make.

This is the point at which the language of selection becomes useful. In the context of Mamba, selection does not mean sparse attention over the past. It does not mean routing tokens to experts. It does not mean explicitly storing some tokens and deleting others from a cache. Selection means something more local and more structural:

the coefficients of the state update become functions of the current token features

That change is subtle, but decisive.

In a fixed recurrence, the hidden state evolves under one stationary rule. In a selective recurrence, the rule itself changes from token to token. The state remains finite. The model still processes the sequence causally. But the update is no longer token-agnostic. The model can now learn to preserve rare, high-value information over long spans while quickly flushing irrelevant content.

This is the compression view of Mamba.

It is tempting to describe Mamba as a hybrid between recurrent models and attention. That language is not entirely wrong, but it can be misleading. The more precise statement is that Mamba preserves the finite-state recurrent backbone of an SSM while introducing token-conditioned control over the update. The architecture does not recover content dependence by revisiting the entire past. It recovers content dependence by changing how the past is compressed into state.

That distinction matters because it determines both the strengths and the computational consequences of the model. The strength is clear: memory becomes adaptive. The consequence is equally important: once the recurrence depends on token content, the model is no longer time-invariant, and this changes how it must be implemented efficiently.

Before turning to that computational story, it is worth making the modeling problem fully concrete.

Imagine a single scalar hidden state $s_t$ updated by

\[s_t = a s_{t-1} + b u_t.\]

Here $u_t$ is the input signal at time $t$, while $a$ and $b$ are fixed coefficients shared across the whole sequence.

This recurrence faces an impossible tradeoff on the selective-copy toy task. If the magnitude of $a$ is small, then the state quickly forgets old information. That helps remove filler tokens, but it also destroys long-range memory. If the magnitude of $a$ is close to one, then the state preserves information for longer, but it also preserves irrelevant content. There is no single fixed choice of $a$ and $b$ that solves both problems cleanly across all tokens.
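The tradeoff can be checked numerically. The sketch below (illustrative constants, not taken from any paper) runs the fixed scalar recurrence with a small and a large $a$, and measures how much of a single marked token survives versus how much filler accumulates:

```python
def final_state(a, b, inputs):
    """Run the fixed recurrence s_t = a * s_{t-1} + b * u_t to the end."""
    s = 0.0
    for u in inputs:
        s = a * s + b * u
    return s

marked = [1.0] + [0.0] * 50   # one marked token, then silence
filler = [0.0] + [0.1] * 50   # no marked token, only distractors

for a in (0.5, 0.99):
    kept = final_state(a, 1.0, marked)   # surviving trace of the marked token
    junk = final_state(a, 1.0, filler)   # accumulated filler
    print(f"a={a}: marked trace {kept:.3g}, filler buildup {junk:.3g}")
```

With $a = 0.5$ the marked token is effectively erased after fifty steps; with $a = 0.99$ it survives, but the accumulated filler is several times larger. No single fixed choice handles both cases.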

The only way out is to let the update depend on the current token.

That is the step Mamba takes.

[Placeholder Figure 2]

A good mental picture is the following. Every recurrent model must answer three questions at each token:

  1. how much of the current state should survive
  2. how much of the current input should be written into memory
  3. how should the resulting state be read out into the output stream

A fixed recurrence answers all three questions with the same rule at every time step. Mamba answers them dynamically.

That is why selection is the core idea, not a secondary design detail.

Why synthetic tasks reveal the issue so clearly

Toy tasks such as selective copying are useful because they isolate the compression problem. They remove most of the confounding factors present in natural language and force the model to do one thing well: preserve a sparse, content-dependent signal over a long interval while ignoring a large amount of distractor content.
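A minimal generator for such a task might look like the following. The vocabulary size, sequence length, and the convention that token 0 is filler are arbitrary choices for illustration, not details from the Mamba paper:

```python
import random

def make_selective_copy_example(length=32, n_marked=3, vocab=8, seed=0):
    """One toy instance: mostly filler (token 0) with a few marked tokens;
    the target is the marked tokens in order of appearance."""
    rng = random.Random(seed)
    positions = sorted(rng.sample(range(length), n_marked))
    seq = [0] * length
    target = []
    for p in positions:
        tok = rng.randrange(1, vocab)   # marked tokens are nonzero
        seq[p] = tok
        target.append(tok)
    return seq, target

seq, target = make_selective_copy_example()
```

A model solving this task must carry the few nonzero tokens across long stretches of zeros, which is exactly the compression problem described above.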

A fixed linear time-invariant recurrence struggles here for structural reasons. If the update rule is unchanged across time, then the model has no direct mechanism for assigning different memory policies to different tokens. It can only hope that a single global timescale and a single write rule will be adequate for every event in the sequence.

That is rarely what language demands. Language alternates constantly between local structure, medium-range dependencies, and sparse long-range anchors. A useful memory system must therefore be able to change its behavior token by token.

This is the conceptual gap that Mamba closes.


3. State space models from first principles

To understand Mamba cleanly, it helps to begin with the standard language of state space models.

A continuous-time linear state space model is written as

\[\dot{x}(t) = A x(t) + B u(t), \qquad y(t) = C x(t) + D u(t).\]

Each term has a direct interpretation.

  • $x(t)$ is the internal state
  • $u(t)$ is the input signal
  • $y(t)$ is the output signal
  • $A$ governs how state evolves even in the absence of new input
  • $B$ determines how input is written into state
  • $C$ determines how state is read into output
  • $D$ provides a direct path from input to output

This formulation is valuable because it separates memory dynamics from observation. The state is the memory. The matrix $A$ describes how that memory changes over time. The pair $B, C$ determines how the external world interacts with the memory.

For sequence modeling, that perspective is exactly what we want. We are not interested in state space models for their own sake. We are interested in them because they provide a structured way to think about recurrent memory.

The continuous-time form is not yet the model used in a language network, because language arrives as discrete tokens rather than as a continuously sampled signal. The next step is therefore to discretize the dynamics.

At a high level, discretization turns the continuous system into a recurrence of the form

\[x_t = \bar{A} x_{t-1} + \bar{B} u_t, \qquad y_t = C x_t + D u_t.\]

Now the hidden state is indexed by token position $t$. The matrix $\bar{A}$ tells us how much of the previous state survives from one token to the next. The matrix $\bar{B}$ determines how the current token is written into state. The output rule is unchanged in spirit. We read from state through $C$ and combine that with a direct skip term $D u_t$.

This is already enough to see the connection between SSMs and recurrent sequence models. The model maintains a state. Each new token updates that state. The output depends on the updated state. The recurrence is causal by construction.

The question is what structure to place on the recurrence so that it is expressive, stable, and computationally tractable.

In a completely generic recurrent model, the state update could involve a dense matrix multiplication across all hidden coordinates at every step. That is flexible, but it is not a particularly enlightening way to think about Mamba. The implementation-friendly viewpoint is more structured.

Mamba is easiest to understand channelwise.

Suppose the model has an inner width $D$, and each channel carries a small state vector of dimension $N$. Then the hidden state at token $t$ can be written as

\[x_t \in \mathbb{R}^{B \times D \times N},\]

where $B$ is batch size. The input stream is

\[u \in \mathbb{R}^{B \times L \times D},\]

where $L$ is sequence length.

In this view, each channel has access to a small collection of state modes. Those modes can be thought of as memory traces with different characteristic timescales. Some decay quickly. Others decay slowly. The output of the channel is then formed by reading from those modes.

This is the first useful picture to keep in mind: an SSM channel is not a single scalar memory. It is a small bank of state coordinates, each of which can preserve information differently over time.

[Placeholder Figure 3]

That picture becomes more concrete once we choose a practical discrete update. Let $A \in \mathbb{R}^{D \times N}$ denote the channelwise state dynamics. Then one implementation-friendly form of the recurrence is

\[x_t[b,d,n] = \bar{A}[d,n] \, x_{t-1}[b,d,n] + \bar{B}[d,n] \, u_t[b,d].\]

The corresponding channel output is

\[y_t[b,d] = \sum_{n=1}^{N} C[d,n] \, x_t[b,d,n] + D[d] \, u_t[b,d].\]

This is already revealing. The state update is cheap because it is elementwise over the state dimension once the coefficients have been chosen. The input for channel $d$ writes into each of the $N$ state coordinates for that channel. The readout sums those coordinates back into one output value for the channel.

If the coefficients $\bar{A}$, $\bar{B}$, and $C$ are fixed across time, then the model is a structured recurrent system with fixed dynamics. This is the starting point for earlier state space sequence models. It gives the model a principled memory mechanism and allows multiple timescales to coexist within each channel. What it does not yet provide is token-dependent control over the update.
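Before adding selection, the fixed recurrence above can be written as a direct token loop. This is a sketch under the channelwise shape conventions of this section; the function name and the random driver values are mine:

```python
import torch

def fixed_ssm_scan(u, A_bar, B_bar, C, D_skip):
    """Time-invariant channelwise SSM, as a direct loop over tokens.

    Shapes:
        u:      B x L x D
        A_bar:  D x N   (fixed decay per channel and state coordinate)
        B_bar:  D x N   (fixed write map)
        C:      D x N   (fixed read map)
        D_skip: D
    """
    Bsz, L, Dm = u.shape
    N = A_bar.shape[-1]
    x = torch.zeros(Bsz, Dm, N)
    ys = []
    for t in range(L):
        u_t = u[:, t, :]                              # B x D
        x = A_bar * x + B_bar * u_t.unsqueeze(-1)     # B x D x N
        y_t = (C * x).sum(-1) + D_skip * u_t          # B x D
        ys.append(y_t)
    return torch.stack(ys, dim=1)                     # B x L x D

y = fixed_ssm_scan(
    torch.randn(2, 5, 3),
    A_bar=torch.full((3, 4), 0.9),
    B_bar=torch.ones(3, 4),
    C=torch.ones(3, 4) / 4,
    D_skip=torch.zeros(3),
)
```

Note that `A_bar`, `B_bar`, and `C` carry no time index: the same policy is applied to every token, which is precisely what selection will relax.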

That is the step from a fixed SSM to Mamba.

A slightly more formal view of discretization

The continuous-time system

\[\dot{x}(t) = A x(t) + B u(t)\]

becomes a discrete recurrence once we choose how the system advances between successive tokens. In the simplest conceptual picture, the discrete transition depends on a step size $\Delta$. A common way to think about the resulting update is

\[x_t = \exp(\Delta A) x_{t-1} + \text{input term}.\]

The matrix exponential here is not a decorative detail. It tells us how the continuous dynamics generated by $A$ propagate over one discrete step of size $\Delta$. If the modes of $A$ are negative, then larger values of $\Delta$ produce stronger decay over that step.

This becomes central in Mamba, where the step size is no longer fixed. Instead, the model learns a token-dependent local clock $\Delta_t$. That turns the discretization itself into part of the model’s adaptive memory policy.
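To see why the exponential behaves like a decay, take a single negative mode and vary the step size (the constants here are illustrative):

```python
import math

A_mode = -1.0   # one negative eigenmode of A

for delta in (0.01, 0.1, 1.0, 5.0):
    keep = math.exp(delta * A_mode)   # fraction of old state surviving one step
    print(f"delta={delta:<5} keeps {keep:.4f} of the previous state")
```

A tiny step carries the state over almost losslessly, while a large step flushes nearly everything. Making $\Delta$ token-dependent therefore amounts to letting each token choose its own position on this spectrum.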


4. From a discrete SSM to the selective recurrence

We can now state the precise move that turns a structured recurrent SSM into Mamba.

A fixed SSM uses the same transition and input maps at every time step. Mamba does not. It allows key parts of the recurrence to depend on the current token. In practice, the most important token-dependent quantities are:

  • the step size $\Delta_t$
  • the write map $B_t$
  • the read map $C_t$

The state dynamics $A$ remain shared across time within a layer. This is an important design choice. The model does not regenerate the entire dynamical system from scratch at every token. Instead, it keeps a stable set of available state modes and modulates how those modes are used from token to token.

The first ingredient is the token-dependent step size.

For each token $t$, and for each channel $d$, Mamba produces a value $\Delta_t[b,d]$. It is useful to interpret $\Delta_t$ as a learned local clock. If $\Delta_t$ is small, then the internal dynamics advance only a little at that token. If $\Delta_t$ is large, then the internal dynamics advance more aggressively. When the modes of $A$ are negative, this means that larger $\Delta_t$ induces stronger forgetting.

Using this local clock, the channelwise discrete transition is written as

\[\bar{A}_t[b,d,n] = \exp\!\left(\Delta_t[b,d] \, A[d,n]\right).\]

This equation is one of the conceptual anchors of the architecture.

It says that the amount of state preserved from one token to the next is no longer fixed. It depends on the current token through $\Delta_t$. Because $A[d,n]$ is typically negative in practice, the quantity $\exp(\Delta_t A)$ acts like a decay factor. Small $\Delta_t$ preserves more of the past. Large $\Delta_t$ decays the old state more strongly.

The second ingredient is token-dependent writing.

At token $t$, the model also produces a write map

\[B_t \in \mathbb{R}^{B \times D \times N}.\]

This map determines how the current input $u_t \in \mathbb{R}^{B \times D}$ is written into the state coordinates. Together with the local clock, it gives the write term

\[\Delta_t[b,d] \, B_t[b,d,n] \, u_t[b,d].\]

The state update therefore becomes

\[x_t[b,d,n] = \bar{A}_t[b,d,n] \, x_{t-1}[b,d,n] + \Delta_t[b,d] \, B_t[b,d,n] \, u_t[b,d].\]

This equation is worth unpacking carefully.

The first term preserves part of the old state. The preservation factor is token-dependent. The second term writes part of the current input into the state. The write strength is also token-dependent. The same token can therefore both control forgetting and control new memory formation.

The third ingredient is token-dependent reading.

At token $t$, the model produces a read map

\[C_t \in \mathbb{R}^{B \times D \times N}.\]

The output for each channel is then

\[y_t[b,d] = \sum_{n=1}^{N} C_t[b,d,n] \, x_t[b,d,n] + D[d] \, u_t[b,d].\]

Here $D \in \mathbb{R}^{D}$ is a learned skip coefficient for each channel. It provides a direct path from the current input to the output, so that not every useful signal must be stored in recurrent state before it can influence the block output.

Putting the three pieces together, the selective recurrence can be written compactly as

\[\bar{A}_t = \exp\!\left(\Delta_t \odot A\right),\] \[x_t = \bar{A}_t \odot x_{t-1} + \left(\Delta_t \odot B_t\right) \odot u_t,\] \[y_t = \langle C_t, x_t \rangle + D \odot u_t.\]

All products above are understood with the natural broadcasting over batch, channel, and state dimensions.

This is the recurrence that matters.

Everything else in a Mamba block exists to produce the quantities that appear in these three lines.

It is useful to read the recurrence as four distinct actions carried out at every token:

  1. preserve part of the old state through $\bar{A}_t \odot x_{t-1}$
  2. write new content through $\left(\Delta_t \odot B_t\right) \odot u_t$
  3. read the updated state through $C_t$
  4. skip directly from input to output through $D \odot u_t$

That decomposition is the cleanest operational understanding of Mamba.

It also makes the meaning of selection precise.

A token that should mostly be ignored can produce a weak write map and a read map that de-emphasizes the resulting state. A token that should preserve long-range memory can produce a small $\Delta_t$, which slows the effective decay. A token that should flush stale information can produce a larger $\Delta_t$, causing the old state to decay more strongly. A token that should sharply alter the current output can change $C_t$, even if the underlying state is similar.

Selection is therefore not a single gate. It is a token-dependent control policy over the recurrence.

[Placeholder Figure 4]

The corresponding one-step implementation is short enough to write directly:

import torch


def selective_step(
    state: torch.Tensor,
    u_t: torch.Tensor,
    delta_t: torch.Tensor,
    A: torch.Tensor,
    B_t: torch.Tensor,
    C_t: torch.Tensor,
    D_skip: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    One selective SSM update.

    Shapes:
        state:   B x D x N
        u_t:     B x D
        delta_t: B x D
        A:       D x N
        B_t:     B x D x N
        C_t:     B x D x N
        D_skip:  D

    Returns:
        y_t:     B x D
        state:   B x D x N
    """
    delta_t_exp = delta_t.unsqueeze(-1)             # B x D x 1
    u_t_exp = u_t.unsqueeze(-1)                     # B x D x 1
    A_expanded = A.unsqueeze(0)                     # 1 x D x N

    A_bar_t = torch.exp(delta_t_exp * A_expanded)   # B x D x N
    write_t = delta_t_exp * B_t * u_t_exp           # B x D x N

    state = A_bar_t * state + write_t               # B x D x N
    y_t = torch.sum(C_t * state, dim=-1)            # B x D
    y_t = y_t + D_skip.unsqueeze(0) * u_t           # B x D
    return y_t, state

This code is almost a direct transcription of the recurrence. It is a useful checkpoint because it makes the tensor geometry explicit before any block structure is added.

Causality is immediate. At token $t$, the model uses only:

  • the previous state $x_{t-1}$
  • the current input $u_t$
  • the current token-conditioned coefficients $\Delta_t$, $B_t$, and $C_t$

There is no dependence on future tokens. The model is causal by construction. This is why autoregressive step-by-step generation fits so naturally into the architecture. The same recurrence used during training can be run one token at a time at inference, provided we carry the state forward.
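As a concrete check of that claim, the one-step update can be unrolled over a whole sequence with an explicit Python loop. The sketch below re-inlines the update so it is self-contained; the shapes follow this post's conventions, but the function name and driver values are mine:

```python
import torch

def selective_scan_loop(u, delta, A, B_seq, C_seq, D_skip):
    """Reference selective scan: the one-step update applied along the sequence.

    Shapes:
        u, delta:     B x L x D
        A:            D x N
        B_seq, C_seq: B x L x D x N
        D_skip:       D
    """
    Bsz, L, Dm = u.shape
    N = A.shape[-1]
    state = torch.zeros(Bsz, Dm, N)
    ys = []
    for t in range(L):
        d_t = delta[:, t, :].unsqueeze(-1)                 # B x D x 1
        u_t = u[:, t, :].unsqueeze(-1)                     # B x D x 1
        A_bar = torch.exp(d_t * A)                         # B x D x N
        state = A_bar * state + d_t * B_seq[:, t] * u_t    # B x D x N
        y_t = (C_seq[:, t] * state).sum(-1) + D_skip * u[:, t, :]
        ys.append(y_t)
    return torch.stack(ys, dim=1)                          # B x L x D

torch.manual_seed(0)
u1 = torch.randn(1, 5, 2)
dl = torch.rand(1, 5, 2)
Am = -torch.rand(2, 3)          # negative modes, so exp(delta * A) decays
Bq = torch.randn(1, 5, 2, 3)
Cq = torch.randn(1, 5, 2, 3)
Dk = torch.zeros(2)
y1 = selective_scan_loop(u1, dl, Am, Bq, Cq, Dk)

# Causality check: perturbing the last token must not change earlier outputs.
u2 = u1.clone()
u2[0, -1, :] = 100.0
y2 = selective_scan_loop(u2, dl, Am, Bq, Cq, Dk)
```

The causality check at the end is the operational form of the statement above: outputs before the perturbed position are bitwise unchanged.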

At this point, the central mechanism of Mamba is already in view. The remaining questions are architectural and computational:

  • Where do $\Delta_t$, $B_t$, and $C_t$ come from?
  • How are they generated from token features inside a block?
  • How do we implement the recurrence cleanly in PyTorch?
  • Why does the loss of time invariance force us to think about scan?

Those are the questions taken up next.

Local derivatives of the selective recurrence

For readers who want the calculus in full, the selective recurrence is simple enough that all local derivatives can be written down explicitly.

Start from

\[x_t = \bar{A}_t \odot x_{t-1} + (\Delta_t \odot B_t) \odot u_t, \qquad \bar{A}_t = \exp(\Delta_t \odot A),\]

and

\[y_t = \langle C_t, x_t \rangle + D \odot u_t.\]

Then the local derivatives at one token satisfy

\[\frac{\partial y_t}{\partial C_t} = x_t, \qquad \frac{\partial y_t}{\partial x_t} = C_t.\]

For the state update, the derivative with respect to the previous state is

\[\frac{\partial x_t}{\partial x_{t-1}} = \bar{A}_t.\]

The derivative with respect to the write map is

\[\frac{\partial x_t}{\partial B_t} = \Delta_t \odot u_t.\]

The derivative with respect to the input is

\[\frac{\partial x_t}{\partial u_t} = \Delta_t \odot B_t.\]

The derivative with respect to $A$ follows from the exponential term:

\[\frac{\partial x_t}{\partial A} = \Delta_t \odot \exp(\Delta_t \odot A) \odot x_{t-1}.\]

The derivative with respect to $\Delta_t$ has two contributions, one through the decay term and one through the write term:

\[\frac{\partial x_t}{\partial \Delta_t} = A \odot \exp(\Delta_t \odot A) \odot x_{t-1} + B_t \odot u_t.\]

These local expressions are enough to derive a reverse-time backward pass for the scan. Conceptually, nothing mysterious is happening. The backward computation is another recurrence, this time over gradients flowing from future tokens to earlier ones.
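These identities are easy to verify against autograd for a single coordinatewise step. The sketch below checks the $\partial x_t / \partial x_{t-1}$ and $\partial x_t / \partial \Delta_t$ expressions, with all shapes collapsed to a flat vector for simplicity:

```python
import torch

# One coordinatewise selective update on a flat vector of coordinates.
x_prev = torch.randn(4, requires_grad=True)
delta = torch.rand(4, requires_grad=True)
A = -torch.rand(4)
B_t = torch.randn(4)
u_t = torch.randn(4)

A_bar = torch.exp(delta * A)
x_t = A_bar * x_prev + delta * B_t * u_t

# Backprop a vector of ones, so each coordinate's gradient is the local derivative.
x_t.backward(torch.ones_like(x_t))

# d x_t / d x_{t-1} = A_bar_t
assert torch.allclose(x_prev.grad, A_bar)

# d x_t / d Delta_t = A * exp(Delta_t * A) * x_{t-1} + B_t * u_t
expected = A * torch.exp(delta * A) * x_prev + B_t * u_t
assert torch.allclose(delta.grad, expected)
```

The same pattern extends to the derivatives with respect to $A$, $B_t$, and $u_t$ listed above.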


5. Interpreting the parameters as a memory policy

At this point, the selective recurrence is on the table. The remaining task is to interpret it correctly.

The equations introduce five learned objects that matter operationally:

  • $A$, which defines the available state dynamics
  • $\Delta_t$, which controls how aggressively those dynamics advance at token $t$
  • $B_t$, which determines how the current token is written into state
  • $C_t$, which determines how the state is read into the output at token $t$
  • $D$, which provides a direct path from the current token to the output

It is tempting to view these as just another collection of learned coefficients. That is technically true, but it misses the deeper structure. A more useful perspective is to view them as a memory policy. The Mamba block is not merely computing an output. At each token, it is deciding how memory should behave.

5.1 $A$ defines the available timescales

The matrix $A$ is the least token-specific part of the recurrence. It is shared across time within a layer, and it defines the dynamical modes that the layer has access to. In the channelwise form used for implementation, $A \in \mathbb{R}^{D \times N}$. For each channel $d$, the vector $A[d, :]$ determines how the $N$ state coordinates in that channel evolve in the absence of new input.

The key idea is that $A$ does not tell the model what to remember at a particular token. Instead, it defines the space of possible forgetting behaviors that the layer can express. It is the background dynamical substrate of the memory system.

A useful way to think about this is that each channel contains a small collection of decay modes. Some modes are effectively short-memory. Others are effectively long-memory. The layer does not reinvent these modes token by token. It learns them once, then modulates their use through the token-dependent quantities.

That design is both expressive and disciplined. It gives the model a stable family of possible timescales while avoiding the cost and instability that would come from regenerating the entire dynamical system at every step.

5.2 $\Delta_t$ chooses how quickly the internal clock advances

If $A$ defines the available timescales, then $\Delta_t$ selects how those timescales are used at token $t$.

Recall the transition factor

\[\bar{A}_t = \exp(\Delta_t \odot A).\]

This is the term that determines how much of the previous state survives. Since $A$ is typically parameterized to be negative in practice, the exponential acts like a decay factor. Larger values of $\Delta_t$ produce stronger decay. Smaller values of $\Delta_t$ preserve more of the past.

This is why $\Delta_t$ is best understood as a learned local clock. It controls how far the internal dynamics move forward at the current token.

That interpretation is operational. If a token signals that old state should be flushed quickly, the model can produce a larger $\Delta_t$. If a token suggests that the existing state should remain active, the model can produce a smaller $\Delta_t$. Different channels can do this differently at the same token, which means that the model can preserve some features while rapidly discarding others.

This is the first place where the recurrence becomes genuinely adaptive.

5.3 $B_t$ determines what is written into memory

The write map $B_t$ controls how the current token enters the state. In the channelwise formulation, $B_t \in \mathbb{R}^{B \times D \times N}$. It tells the model how the scalar input signal for channel $d$ should be distributed across the $N$ state coordinates of that channel.

The write term is

\[(\Delta_t \odot B_t) \odot u_t.\]

The role of $B_t$ is therefore not simply to scale the input. It shapes the geometry of the write into state. Two tokens with similar magnitudes in the input stream can still have very different effects on memory if they produce different write maps.

This is a crucial point. In a fixed SSM, the input is always written into memory through the same map. In Mamba, the write map depends on the current token features. This means that the model can decide that some tokens should write strongly into slow modes, some should write into fast modes, and some should barely alter the state at all.

Selection is already visible here. A token that carries little lasting value need not be written strongly into memory. A token that should leave a durable trace can be routed into state coordinates that decay more slowly.
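A numeric caricature makes the point (constants chosen for illustration only): write the same unit value into a fast mode and a slow mode of one channel, then let twenty tokens pass with no further writes.

```python
import math

delta = 0.5
A_fast, A_slow = -5.0, -0.1           # two modes in one channel
decay_fast = math.exp(delta * A_fast)
decay_slow = math.exp(delta * A_slow)

trace_fast = trace_slow = 1.0          # the same token written into each mode
for _ in range(20):                    # twenty quiet steps afterwards
    trace_fast *= decay_fast
    trace_slow *= decay_slow

print(trace_fast, trace_slow)
```

The slow mode still holds roughly a third of the original trace after twenty steps, while the fast mode has erased it almost completely. A write map that routes durable tokens into slow modes and disposable tokens into fast modes is exactly the kind of policy $B_t$ can express.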

5.4 $C_t$ determines how memory is read

If $B_t$ is the token-dependent write rule, then $C_t$ is the token-dependent read rule.

The output equation is

\[y_t[b,d] = \sum_{n=1}^{N} C_t[b,d,n] \, x_t[b,d,n] + D[d] \, u_t[b,d].\]

Here the state has already been formed. The question is how that state should influence the output at the current token. The read map $C_t$ answers that question.

This is more subtle than it may first appear. The same underlying state can support different outputs depending on how it is read. In other words, token dependence is not limited to writing. Mamba also allows the model to change the readout of memory from token to token.

That flexibility matters because memory is not useful in the abstract. It is useful only insofar as the model can expose the right part of it at the right moment. A token that should trigger a strong use of previously stored information can produce a read map that emphasizes the relevant state coordinates. A token that should rely more on local information can downweight recurrent readout.

This gives the block a second axis of selectivity. It can decide not only what should enter memory, but also how the current memory should be interpreted.

5.5 $D$ preserves a direct path around recurrence

The skip term $D \odot u_t$ is easy to underestimate, but it plays an important role.

Not every useful feature needs to be routed through long-range state. Some information is local and should influence the output immediately. The skip path allows this. It gives the block a short route from input to output without forcing every signal through the recurrent state.

This matters for both optimization and representation. Optimization benefits because the block need not solve every task by manipulating long-memory state. Representation benefits because local and recurrent information can coexist naturally. The model can use the recurrent state for durable structure while still preserving a direct token-level contribution.

5.6 A compact mental model

The most concise way to summarize the selective recurrence is this:

  • $A$ defines the repertoire of memory timescales
  • $\Delta_t$ decides how quickly those timescales advance at token $t$
  • $B_t$ decides what is written into memory
  • $C_t$ decides how memory is read into the output
  • $D$ preserves a direct path for local information

This is the memory-policy view of Mamba.

It is a better mental model than treating the recurrence as a bag of coefficients, because it immediately tells you what to inspect when behavior goes wrong. If the model forgets too quickly, look at $\Delta_t$ and the effective decays induced by $A$. If it fails to preserve important events, inspect what $B_t$ is writing and where that content is being stored. If the state looks rich but the output remains uninformative, inspect the read map $C_t$.

[Placeholder Figure 5]

How this differs from classical gating in recurrent networks

It is natural to compare Mamba to gated recurrent architectures such as the LSTM or GRU. The comparison is useful, but only if stated carefully.

In an LSTM or GRU, the recurrence contains a small number of named gates with a fixed structural role. One gate decides how much of the previous state to retain. Another decides how much of a candidate update to admit. The architecture is explicitly built around that gating template.

Mamba is different in two ways.

First, its state update is expressed through the language of discretized state space dynamics rather than through a fixed gate template. The model is not merely scaling a hidden vector coordinatewise. It is modulating a structured dynamical system.

Second, the token-dependent quantities in Mamba do more than gate a candidate state. They alter the effective transition, the write rule, and the read rule. In that sense, Mamba is better understood as a token-conditioned controller over recurrent memory than as a standard gated RNN.

This distinction becomes especially important when thinking about timescales. In Mamba, the local clock $\Delta_t$ has a direct interpretation in terms of how quickly the internal dynamics advance. That is a more dynamical view of memory control than the usual gate-based picture.


6. The smallest useful Mamba block

The selective recurrence is the core mechanism, but it is not yet a full neural network block. We still need to answer a practical question:

Where do the token-dependent quantities $\Delta_t$, $B_t$, and $C_t$ come from?

A Mamba block answers that question by generating them from the input sequence itself. The block takes a token representation, processes it through a small amount of local mixing and projection, and uses the result to parameterize the recurrence.

The goal of this section is not to reproduce every detail of a production implementation. The goal is to isolate the smallest useful block that makes the architecture intelligible.

6.1 The ingredients of the block

A clean pedagogical Mamba block contains the following components:

  1. a normalization layer on the input stream
  2. an input projection that splits the features into two streams
  3. a short causal depthwise convolution on the recurrent stream
  4. a parameter head that produces $\Delta$, $B$, and $C$
  5. the selective scan itself
  6. an output gate driven by the second stream
  7. a projection back to the model dimension
  8. a residual connection

This is already enough to expose the architecture’s logic. The block takes token features, forms a content-dependent memory policy, runs the selective recurrence, and then filters the result before returning it to the residual stream.

6.2 Two streams, two roles

The first major design choice is the split into two streams. After normalization, the input is projected into an $x$-stream and a $z$-stream.

The $x$-stream is the recurrent stream. It carries the features that will be locally mixed, transformed into recurrence parameters, and then scanned through state.

The $z$-stream is the output-gating stream. It does not directly define the state update. Instead, it modulates what leaves the recurrent path and returns to the residual pathway.

This division of labor is simple but effective. The recurrent stream is responsible for constructing and using memory. The gate stream is responsible for deciding how much of that recurrent output should be exposed downstream.

That makes the block easier to interpret. One stream builds the memory update. The other stream filters the contribution of memory to the final output.

6.3 Why local mixing appears before the recurrence

Before generating $\Delta$, $B$, and $C$, the recurrent stream is typically passed through a short causal depthwise convolution.

This is not an arbitrary flourish. It serves a clear purpose.

If the parameter head looked only at the current token in isolation, then the memory policy at token $t$ would be determined entirely by that one token representation. In practice, short-range context often matters. The meaning of a token can depend strongly on a few nearby tokens. A short causal convolution gives the block access to this local context before it decides how memory should update.

The convolution is depthwise, which means each channel is convolved independently. That keeps the operation lightweight. It is causal, which means it does not leak future information. And it is local, which means it injects a small receptive field without trying to replace the long-range recurrent memory.

This is an important division of labor. The convolution handles short-range pattern formation. The selective recurrence handles long-range memory.
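The causality claim is easy to verify numerically. The sketch below (with arbitrary weights, independent of the block implementation later in this post) left-pads the sequence by $K - 1$ and checks that perturbing a future input never changes an earlier output:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, D, L, K = 2, 4, 10, 3

x = torch.randn(B, D, L)               # channels-first layout for conv1d
weight = torch.randn(D, 1, K)          # depthwise: one independent filter per channel

def causal_dwconv(x: torch.Tensor) -> torch.Tensor:
    # Left-pad by K - 1 so the output at position t sees only inputs <= t.
    x_padded = F.pad(x, (K - 1, 0))
    return F.conv1d(x_padded, weight, groups=D)

y = causal_dwconv(x)
assert y.shape == x.shape              # same length out, no future positions consumed

# Perturbing the final time step leaves every earlier output unchanged.
x_perturbed = x.clone()
x_perturbed[:, :, -1] += 100.0
y_perturbed = causal_dwconv(x_perturbed)
assert torch.allclose(y[:, :, :-1], y_perturbed[:, :, :-1])
```

The `groups=D` argument is what makes the convolution depthwise: each channel is convolved only with its own filter, so the cost stays linear in the number of channels.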

6.4 End-to-end structure of the block

Let the block input be

\[h \in \mathbb{R}^{B \times L \times D_{\text{model}}}.\]

After normalization, we project into an inner width $D_{\text{inner}}$ and split into two streams:

\[[x, z] = \operatorname{split}\!\left(W_{\text{in}} \, \mathrm{Norm}(h)\right), \qquad x, z \in \mathbb{R}^{B \times L \times D_{\text{inner}}}.\]

The recurrent stream is passed through a short causal depthwise convolution and nonlinearity:

\[x' = \mathrm{SiLU}\!\left(\mathrm{DWConv}_{\text{causal}}(x)\right).\]

From the convolved features $x'$, the block predicts the selective parameters:

\[(\Delta, B, C) = f_{\text{param}}(x').\]

These are then fed into the selective scan:

\[y = \mathrm{SelectiveScan}(x', \Delta, A, B, C, D_{\text{skip}}).\]

The output is gated by the second stream:

\[\hat{y} = y \odot \mathrm{SiLU}(z).\]

Finally, the result is projected back to the model width and added to the residual stream:

\[\mathrm{out} = W_{\text{out}} \hat{y} + h.\]

This is the smallest useful Mamba block.

Everything in it has a clear role:

  • normalization stabilizes the input statistics
  • the input projection creates a recurrent path and a gate path
  • the causal convolution provides local context
  • the parameter head converts local features into a memory policy
  • the selective scan performs the long-range recurrent update
  • the gate controls the block output
  • the output projection returns to the model width
  • the residual connection preserves the standard deep network interface

6.5 What this block is really doing

It is useful to describe the entire block in one sentence:

The block uses local context to decide how memory should update, then runs that memory update recurrently across the sequence, and finally gates how much of the recurrent result should be exposed.

That description is more informative than simply listing the layers, because it identifies the functional logic of the block.

The block is not just another sequence module with some projections and a convolution. It is a learned controller over recurrent memory.

6.6 A pedagogical implementation

The following implementation is intentionally compact. It shows the full block with all important tensors exposed. It relies on helper classes defined in this post and on the reference scan introduced later.

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            x: ... x D_model

        Returns:
            out: ... x D_model
        """
        rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
        x_norm = x / rms
        return x_norm * self.weight


class DepthwiseCausalConv1d(nn.Module):
    def __init__(self, d_inner: int, kernel_size: int) -> None:
        super().__init__()
        self.d_inner = d_inner
        self.kernel_size = kernel_size

        self.weight = nn.Parameter(torch.empty(d_inner, kernel_size))
        self.bias = nn.Parameter(torch.zeros(d_inner))
        nn.init.normal_(self.weight, mean=0.0, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            x:   B x L x D
            out: B x L x D
        """
        B_size, L, D_inner = x.shape
        x_channels_first = x.transpose(1, 2)                        # B x D x L
        x_padded = F.pad(x_channels_first, (self.kernel_size - 1, 0))
        weight = self.weight.unsqueeze(1)                           # D x 1 x K

        y = F.conv1d(
            x_padded,
            weight=weight,
            bias=self.bias,
            stride=1,
            padding=0,
            groups=self.d_inner,
        )                                                           # B x D x L
        return y.transpose(1, 2)                                    # B x L x D

    def init_state(self, batch_size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
        """
        Returns:
            conv_state: B x D x K_minus_1
        """
        return torch.zeros(
            batch_size,
            self.d_inner,
            self.kernel_size - 1,
            device=device,
            dtype=dtype,
        )

    def step(self, x_t: torch.Tensor, conv_state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        One causal depthwise conv step.

        Shapes:
            x_t:        B x D
            conv_state: B x D x K_minus_1

        Returns:
            y_t:        B x D
            new_state:  B x D x K_minus_1
        """
        x_t_exp = x_t.unsqueeze(-1)                                  # B x D x 1
        window = torch.cat([conv_state, x_t_exp], dim=-1)           # B x D x K

        y_t = torch.sum(window * self.weight.unsqueeze(0), dim=-1)  # B x D
        y_t = y_t + self.bias.unsqueeze(0)                          # B x D

        new_state = window[:, :, 1:]                                # B x D x K_minus_1
        return y_t, new_state


class MinimalMambaBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_inner: int,
        d_state: int,
        d_conv: int,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner
        self.d_state = d_state
        self.d_conv = d_conv

        self.norm = RMSNorm(d_model)
        self.in_proj = nn.Linear(d_model, 2 * d_inner, bias=False)
        self.conv = DepthwiseCausalConv1d(d_inner=d_inner, kernel_size=d_conv)

        self.param_proj = nn.Linear(
            d_inner,
            d_inner + 2 * d_inner * d_state,
            bias=False,
        )

        self.A_log = nn.Parameter(torch.zeros(d_inner, d_state))  # zeros => A = -exp(0) = -1 for every mode at init
        self.D_skip = nn.Parameter(torch.ones(d_inner))
        self.delta_bias = nn.Parameter(torch.zeros(d_inner))

        self.out_proj = nn.Linear(d_inner, d_model, bias=False)

    def _project_selective_params(
        self,
        x_conv: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Shapes:
            x_conv: B x L x D_inner

        Returns:
            delta: B x L x D_inner
            B_map: B x L x D_inner x N
            C_map: B x L x D_inner x N
        """
        B_size, L, _ = x_conv.shape

        params = self.param_proj(x_conv)                               # B x L x [D + 2DN]
        delta_raw = params[:, :, : self.d_inner]                       # B x L x D

        BC_raw = params[:, :, self.d_inner :]                          # B x L x 2DN
        split_size = self.d_inner * self.d_state
        B_raw = BC_raw[:, :, :split_size]                              # B x L x DN
        C_raw = BC_raw[:, :, split_size:]                              # B x L x DN

        B_map = B_raw.view(B_size, L, self.d_inner, self.d_state)     # B x L x D x N
        C_map = C_raw.view(B_size, L, self.d_inner, self.d_state)     # B x L x D x N

        delta = F.softplus(delta_raw + self.delta_bias.view(1, 1, -1))
        return delta, B_map, C_map

    def init_state(
        self,
        batch_size: int,
        device: torch.device,
        dtype: torch.dtype,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            ssm_state:  B x D_inner x N
            conv_state: B x D_inner x K_minus_1
        """
        ssm_state = torch.zeros(
            batch_size,
            self.d_inner,
            self.d_state,
            device=device,
            dtype=dtype,
        )
        conv_state = self.conv.init_state(batch_size, device=device, dtype=dtype)
        return ssm_state, conv_state

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Shapes:
            x:   B x L x D_model
            out: B x L x D_model
        """
        residual = x                                                   # B x L x D_model
        x_norm = self.norm(x)                                          # B x L x D_model

        xz = self.in_proj(x_norm)                                      # B x L x 2D_inner
        x_stream, z_stream = torch.split(xz, self.d_inner, dim=-1)     # each B x L x D_inner

        x_conv = self.conv(x_stream)                                   # B x L x D_inner
        x_conv = F.silu(x_conv)                                        # B x L x D_inner

        delta, B_map, C_map = self._project_selective_params(x_conv)
        A = -torch.exp(self.A_log)                                     # D_inner x N

        y = selective_scan_reference(
            u=x_conv,
            delta=delta,
            A=A,
            B_map=B_map,
            C_map=C_map,
            D_skip=self.D_skip,
        )                                                              # B x L x D_inner

        y = y * F.silu(z_stream)                                       # B x L x D_inner
        out = self.out_proj(y) + residual                              # B x L x D_model
        return out

    def step(
        self,
        x_t: torch.Tensor,
        state: tuple[torch.Tensor, torch.Tensor],
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        One-token block update.

        Shapes:
            x_t:        B x D_model
            ssm_state:  B x D_inner x N
            conv_state: B x D_inner x K_minus_1

        Returns:
            out_t:      B x D_model
            new_state:
                ssm_state:  B x D_inner x N
                conv_state: B x D_inner x K_minus_1
        """
        ssm_state, conv_state = state
        residual = x_t                                                 # B x D_model

        x_norm = self.norm(x_t)                                        # B x D_model
        xz = self.in_proj(x_norm)                                      # B x 2D_inner
        x_stream, z_stream = torch.split(xz, self.d_inner, dim=-1)     # each B x D_inner

        x_conv_t, conv_state = self.conv.step(x_stream, conv_state)    # B x D_inner
        x_conv_t = F.silu(x_conv_t)                                    # B x D_inner

        params_t = self.param_proj(x_conv_t)                           # B x [D + 2DN]
        delta_raw_t = params_t[:, : self.d_inner]                      # B x D

        split_size = self.d_inner * self.d_state
        BC_raw_t = params_t[:, self.d_inner :]                         # B x 2DN
        B_raw_t = BC_raw_t[:, :split_size]                             # B x DN
        C_raw_t = BC_raw_t[:, split_size:]                             # B x DN

        B_t = B_raw_t.view(x_t.shape[0], self.d_inner, self.d_state)   # B x D x N
        C_t = C_raw_t.view(x_t.shape[0], self.d_inner, self.d_state)   # B x D x N
        delta_t = F.softplus(delta_raw_t + self.delta_bias.unsqueeze(0))

        A = -torch.exp(self.A_log)                                     # D_inner x N

        y_t, ssm_state = selective_step(
            state=ssm_state,
            u_t=x_conv_t,
            delta_t=delta_t,
            A=A,
            B_t=B_t,
            C_t=C_t,
            D_skip=self.D_skip,
        )                                                              # y_t is B x D_inner

        y_t = y_t * F.silu(z_stream)                                   # B x D_inner
        out_t = self.out_proj(y_t) + residual                          # B x D_model
        return out_t, (ssm_state, conv_state)

For this tutorial, the right implementation target is not a fused high-performance block. It is a clean PyTorch block whose internals remain visible. The code should make the tensor geometry obvious. It should show exactly where $\Delta$, $B$, and $C$ come from. And it should make the selective scan appear as the conceptual center of the block, rather than burying it under dispatch logic.

That is what the block above is meant to do.

[Placeholder Figure 6]

What is simplified in the pedagogical block

A tutorial implementation should be faithful to the mechanism, but it need not reproduce every engineering choice of a production model.

The pedagogical version intentionally simplifies a few things.

First, it focuses on the core selective recurrence rather than on fusion or custom kernels. The point is to make the state update legible.

Second, it uses explicit tensor shapes and clear intermediate variables, even where a production implementation might collapse operations together for speed.

Third, it presents the block as a self-contained module rather than as one component in a larger codebase with dispatch layers, precision policies, and device-specific branches.

These simplifications are not concessions to correctness. They are decisions in favor of exposition. A reader who understands the clean reference block is in a much stronger position to reason about optimized versions later.


7. Stable parameterization and timescales

The selective recurrence is expressive, but expression alone is not enough. The model also needs a numerically sane way to parameterize its dynamics.

This section addresses two small but important implementation choices:

  • why $A$ is usually parameterized as a negative exponential
  • why $\Delta_t$ is usually forced to be positive

These choices are easy to treat as mere implementation folklore. They deserve a more principled explanation.

7.1 Why $A$ is parameterized as a negative exponential

In the practical channelwise recurrence, the transition factor is

\[\bar{A}_t = \exp(\Delta_t \odot A).\]

If $A$ were unconstrained, some of its entries could become strongly positive. Since $\Delta_t$ is also positive in most implementations, that would cause $\exp(\Delta_t A)$ to grow rather than decay. The resulting state dynamics could become highly unstable, especially early in training.

A common solution is to parameterize $A$ as

\[A = -\exp(A_{\log}).\]

This does two things immediately.

First, it ensures that every state mode is negative. That biases the dynamics toward decay rather than uncontrolled amplification.

Second, it lets optimization occur in log-space. The model learns $A_{\log}$, and the actual dynamical coefficients are obtained by exponentiating and negating. This is often a smoother way to learn a family of positive magnitudes than directly optimizing a constrained parameter.

The effect is simple: the model learns a collection of decay rates, and those decay rates are guaranteed to have the correct sign.

This does not solve every numerical problem. It does not guarantee perfect training behavior. But it provides a strong inductive bias toward well-behaved recurrent dynamics, which is especially valuable when the transition is itself being modulated token by token through $\Delta_t$.

7.2 Why $\Delta_t$ is made positive

The step size $\Delta_t$ has a direct interpretation as a learned local clock. That interpretation becomes difficult to sustain if $\Delta_t$ is allowed to change sign arbitrarily.

If $\Delta_t$ is positive and $A$ is negative, then the transition factor

\[\exp(\Delta_t A)\]

lies between zero and one and acts like a decay. Larger $\Delta_t$ means stronger forgetting. Smaller $\Delta_t$ means slower forgetting. This is a coherent and interpretable behavior.

If $\Delta_t$ were allowed to be negative, then the model could effectively reverse that logic and turn decays into growth factors. While such behavior is not mathematically forbidden, it makes the recurrence much harder to interpret and control.

For this reason, $\Delta_t$ is usually produced through a positive transform such as

\[\Delta_t = \mathrm{softplus}\!\left(\Delta_t^{\mathrm{raw}} + b_\Delta\right).\]

The bias term allows the model to initialize the local clocks in a sensible regime, while the softplus ensures positivity.

This is not merely a convenience. It turns $\Delta_t$ into a real timescale control variable.
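A short numerical illustration of how the softplus and bias interact with a negative mode (the raw values and the bias of $-1$ are arbitrary, chosen only for the sketch):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical raw head outputs and a bias that sets the initial clock regime.
delta_raw = torch.randn(8)
b_delta = torch.full((8,), -1.0)

delta = F.softplus(delta_raw + b_delta)
assert (delta > 0).all()               # softplus guarantees a positive local clock

# With a negative mode, the transition factor is a genuine decay in (0, 1).
A = -torch.ones(8)
decay = torch.exp(delta * A)
assert ((decay > 0) & (decay < 1)).all()

# Larger delta means stronger forgetting: decay shrinks as delta grows.
order = torch.argsort(delta)
assert (decay[order][:-1] >= decay[order][1:]).all()
```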

7.3 The interaction of $A$ and $\Delta_t$

The transition factor depends on the product $\Delta_t A$. This means that $A$ and $\Delta_t$ do not act independently. They work together to determine the effective forgetting rate at each token.

A useful way to think about the relationship is:

  • $A$ defines the available spectrum of decays
  • $\Delta_t$ selects where, within that spectrum, the model operates at token $t$

This perspective helps explain why both are needed. If only $A$ were learned, the model would have a fixed family of timescales but no token-dependent control over them. If only $\Delta_t$ varied but $A$ were too narrow or poorly placed, the model would not have a rich enough collection of dynamical modes to exploit.

The power of the recurrence comes from their interaction.
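The interaction can be made concrete with a toy spectrum. In the sketch below (with made-up magnitudes), $A$ supplies modes at three scales, and two hypothetical values of $\Delta_t$ shift where in that spectrum the model operates, without reordering the modes:

```python
import torch

# A spectrum of negative modes, shared across tokens.
A = -torch.tensor([0.1, 1.0, 10.0])

# Two hypothetical local clocks at different tokens.
decay_slow = torch.exp(0.1 * A)   # small delta: every mode retains state longer
decay_fast = torch.exp(2.0 * A)   # large delta: every mode forgets faster

# A larger clock uniformly strengthens forgetting across the spectrum...
assert (decay_fast < decay_slow).all()

# ...but preserves the relative ordering of the modes: delta rescales
# the whole spectrum rather than replacing it.
assert (decay_slow.argsort() == decay_fast.argsort()).all()
```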

7.4 What poor timescale parameterization looks like

This discussion becomes more concrete if we imagine failure modes.

If $\Delta_t$ collapses to very small values across most tokens, then the model barely advances its internal clock. Old state persists too strongly. Memory becomes sticky. The model can retain information, but it may struggle to clear stale content.

If $\Delta_t$ becomes very large across most tokens, then the opposite happens. The transition factor shrinks rapidly, and old state is aggressively erased. The model updates strongly from the current token, but long-range memory becomes weak.

If the entries of $A$ cluster too tightly, then the layer may fail to develop a diverse collection of timescales. If they are spread too unevenly, some modes may become effectively irrelevant while others dominate excessively.

The point is that memory quality is not determined by a single scalar knob. It emerges from the learned interaction between a shared spectrum of decay modes and a token-dependent local clock.

7.5 Practical interpretation

The advantage of this formulation is that the recurrence becomes interpretable in a way that generic hidden-state updates often are not.

If you inspect a trained model and observe large $\Delta_t$ values around punctuation or abrupt context changes, that is meaningful. It suggests the model is using those tokens to clear or refresh state. If you observe small $\Delta_t$ values around durable entities or structural anchors, that is also meaningful. It suggests the model is slowing down forgetting in those regions.

Likewise, inspecting the learned $A$ values reveals what timescales the layer has made available to itself. Some channels may specialize in shorter-range structure. Others may maintain slower traces.

This is one of the conceptual strengths of the Mamba formulation. The recurrence is not merely trainable. It is interpretable in terms of memory timescales and token-dependent control.

[Placeholder Figure 7]

Why log-parameterizing decay modes is often preferable

Suppose we want to learn a positive decay magnitude $\lambda > 0$ and then use $-\lambda$ as the corresponding stable mode. A direct parameterization of $\lambda$ requires either constrained optimization or ad hoc clipping. Neither is especially elegant.

By writing

\[\lambda = \exp(\theta), \qquad A = -\lambda,\]

we convert the positivity constraint into an unconstrained parameter $\theta \in \mathbb{R}$. The optimization now happens in log-space. Small additive changes in $\theta$ correspond to multiplicative changes in $\lambda$, which is often a more natural geometry for quantities that represent timescales.

This is particularly useful because useful decay rates often span several orders of magnitude. A linear parameterization can be awkward in such settings, while a log-parameterization handles multiplicative scale variation naturally.
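A tiny arithmetic check of the log-space geometry: equal additive steps in $\theta$ produce equal multiplicative steps in $\lambda$, so a short range of $\theta$ already covers decay magnitudes across several orders of magnitude.

```python
import math

# Equal additive steps in theta...
thetas = [-4.0, -2.0, 0.0, 2.0]
lambdas = [math.exp(t) for t in thetas]

# ...correspond to equal multiplicative steps in lambda = exp(theta).
ratios = [lambdas[i + 1] / lambdas[i] for i in range(len(lambdas) - 1)]
assert all(abs(r - math.exp(2.0)) < 1e-9 for r in ratios)

# A span of 6 in theta covers roughly three orders of magnitude in lambda.
assert lambdas[-1] / lambdas[0] > 400.0
```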


8. Reference selective scan in PyTorch

We now have everything needed to write the recurrence directly.

This section is deliberately practical. The goal is to implement the selective scan in a form that is mathematically transparent and easy to verify. It is not yet the place to optimize aggressively. The correct first implementation keeps the sequence dependence explicit.

That point matters.

When people hear that the implementation is “vectorized,” they sometimes expect the entire time dimension to disappear into a single dense tensor expression. That is not what should happen here. The recurrence is genuinely sequential in time. The state at token $t$ depends on the state at token $t - 1$. A clean reference implementation should therefore preserve an explicit loop over the sequence length.

What can be vectorized is everything else:

  • batch dimension
  • channel dimension
  • state dimension

This gives us a readable and reasonably efficient baseline.

8.1 Shape conventions

We use the following tensor shapes throughout:

  • input stream
    $u \in \mathbb{R}^{B \times L \times D}$

  • token-dependent step sizes
    $\Delta \in \mathbb{R}^{B \times L \times D}$

  • shared dynamics
    $A \in \mathbb{R}^{D \times N}$

  • token-dependent write map
    $B \in \mathbb{R}^{B \times L \times D \times N}$

  • token-dependent read map
    $C \in \mathbb{R}^{B \times L \times D \times N}$

  • skip coefficient
    $D_{\mathrm{skip}} \in \mathbb{R}^{D}$

  • recurrent state
    $x_t \in \mathbb{R}^{B \times D \times N}$

  • output sequence
    $y \in \mathbb{R}^{B \times L \times D}$

This is the exact shape geometry that matters for the reference implementation. Once these dimensions are internalized, the recurrence becomes almost mechanical to code.

8.2 The update at one token

At a single token $t$, we extract

  • $u_t = u[:, t, :]$ with shape $B \times D$
  • $\Delta_t = \Delta[:, t, :]$ with shape $B \times D$
  • $B_t = B[:, t, :, :]$ with shape $B \times D \times N$
  • $C_t = C[:, t, :, :]$ with shape $B \times D \times N$

We then broadcast $\Delta_t$ and $u_t$ across the state dimension to form the update:

\[\bar{A}_t = \exp(\Delta_t \odot A),\]

\[x_t = \bar{A}_t \odot x_{t-1} + (\Delta_t \odot B_t) \odot u_t,\]

\[y_t = \langle C_t, x_t \rangle + D_{\mathrm{skip}} \odot u_t.\]

The only subtlety is the broadcasting. In code, both $u_t$ and $\Delta_t$ must be unsqueezed along the state dimension so that they become shape $B \times D \times 1$.

Once that is done, the update is fully elementwise over the $B \times D \times N$ state tensor.
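This per-token update can be written directly as a helper. The following is a minimal sketch of the `selective_step` function used by the block and the reference scan, following the shape conventions of Section 8.1 (a production version may fuse or reorganize these operations):

```python
import torch


def selective_step(
    state: torch.Tensor,    # B x D x N   recurrent state x_{t-1}
    u_t: torch.Tensor,      # B x D       current input
    delta_t: torch.Tensor,  # B x D       local step sizes
    A: torch.Tensor,        # D x N       shared (negative) modes
    B_t: torch.Tensor,      # B x D x N   token-dependent write map
    C_t: torch.Tensor,      # B x D x N   token-dependent read map
    D_skip: torch.Tensor,   # D           skip coefficients
) -> tuple[torch.Tensor, torch.Tensor]:
    # Broadcast delta_t and u_t across the state dimension N.
    delta_exp = delta_t.unsqueeze(-1)          # B x D x 1
    u_exp = u_t.unsqueeze(-1)                  # B x D x 1

    A_bar_t = torch.exp(delta_exp * A)         # B x D x N   token-dependent decay
    write_t = delta_exp * B_t * u_exp          # B x D x N   token-dependent write
    state = A_bar_t * state + write_t          # the recurrence itself

    y_t = torch.sum(C_t * state, dim=-1)       # B x D       channelwise readout over N
    y_t = y_t + D_skip.unsqueeze(0) * u_t      # direct skip path
    return y_t, state
```

Note that every operation after the two `unsqueeze` calls is elementwise over the $B \times D \times N$ state tensor, except for the final contraction over $N$.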

8.3 The reference implementation

A clean reference implementation in PyTorch looks like this:

import torch


def selective_scan_reference(
    u: torch.Tensor,
    delta: torch.Tensor,
    A: torch.Tensor,
    B_map: torch.Tensor,
    C_map: torch.Tensor,
    D_skip: torch.Tensor,
    x0: torch.Tensor | None = None,
    return_final_state: bool = False,
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    """
    Reference selective scan.

    Shapes:
        u:       B x L x D
        delta:   B x L x D
        A:       D x N
        B_map:   B x L x D x N
        C_map:   B x L x D x N
        D_skip:  D
        x0:      B x D x N or None

    Returns:
        y:       B x L x D
        final_x: B x D x N if return_final_state is True
    """
    B_size, L, D_inner = u.shape
    N = A.shape[1]

    if x0 is None:
        state = torch.zeros(B_size, D_inner, N, device=u.device, dtype=u.dtype)
    else:
        state = x0

    outputs = []

    for t in range(L):
        u_t = u[:, t, :]                # B x D
        delta_t = delta[:, t, :]        # B x D
        B_t = B_map[:, t, :, :]         # B x D x N
        C_t = C_map[:, t, :, :]         # B x D x N

        y_t, state = selective_step(
            state=state,
            u_t=u_t,
            delta_t=delta_t,
            A=A,
            B_t=B_t,
            C_t=C_t,
            D_skip=D_skip,
        )
        outputs.append(y_t)

    y = torch.stack(outputs, dim=1)     # B x L x D

    if return_final_state:
        return y, state
    return y

This function is deliberately plain. It does not hide the recurrence behind advanced abstractions. It does not fuse operations. It does not attempt to eliminate the time loop. For a first implementation, those are strengths rather than weaknesses.

Each line of the per-token update performed by selective_step has a direct mathematical interpretation.

  • A_bar_t is the token-dependent decay
  • write_t is the token-dependent write into state
  • state = A_bar_t * state + write_t is the recurrence itself
  • torch.sum(C_t * state, dim=-1) is the channelwise readout
  • the skip term adds the direct input path

That is exactly what a good reference implementation should look like.

8.4 Why the time loop should remain

A natural question is whether the loop over t can be removed.

For the straightforward reference implementation, the answer is no. The hidden state at token $t$ depends on the hidden state at token $t - 1$. This dependence is not an artifact of the code. It is the structure of the model.

The right way to think about vectorization here is therefore selective:

  • vectorize over batch
  • vectorize over channels
  • vectorize over state coordinates
  • keep the recurrence over time explicit

That is the most honest expression of the algorithm.

Later, when scan enters the picture, there is a deeper algebraic structure that allows chunkwise composition and parallelization. But that is a second step. The first step is to write the recurrence exactly as it is.

8.5 Common implementation mistakes

There are several easy ways to write code that looks plausible but computes the wrong recurrence.

The first is to forget to unsqueeze u_t or delta_t over the state dimension. Since both are shape $B \times D$, they must be reshaped to $B \times D \times 1$ before they can interact correctly with the $B \times D \times N$ state tensor.

The second is to sum over the wrong axis during readout. The contraction with $C_t$ should sum over the state dimension $N$, not over the channel dimension $D$.

The third is to confuse the batch size $B$ with the write map $B_t$. In code, it is worth naming the write map something like B_map to reduce ambiguity.

The fourth is to forget the skip path. The direct term $D_{\mathrm{skip}} \odot u_t$ is not decorative. It is part of the block’s expressive design.

The fifth is to silently change tensor layout conventions midway through the implementation. The recurrence becomes much easier to reason about if the input stream stays token-major with shape $B \times L \times D$, while the recurrent state stays shape $B \times D \times N$.
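The first two mistakes can be made concrete in a few lines (shapes only, arbitrary values):

```python
import torch

B, D, N = 2, 3, 4
state = torch.randn(B, D, N)
delta_t = torch.rand(B, D)
A = -torch.rand(D, N)
C_t = torch.randn(B, D, N)

# Mistake 1: without unsqueezing, (B, D) cannot broadcast against (D, N).
try:
    _ = delta_t * A
    raise AssertionError("expected a broadcasting error")
except RuntimeError:
    pass

# Unsqueezed to B x D x 1, the product broadcasts cleanly to B x D x N.
A_bar = torch.exp(delta_t.unsqueeze(-1) * A)
assert A_bar.shape == (B, D, N)

# Mistake 2: the readout must contract over the state dimension N (dim=-1)...
y_good = torch.sum(C_t * state, dim=-1)
assert y_good.shape == (B, D)

# ...contracting over the channel dimension D (dim=1) yields the wrong shape.
y_bad = torch.sum(C_t * state, dim=1)
assert y_bad.shape == (B, N)
```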

8.6 Why this implementation is the correct starting point

This reference scan is not only readable. It is also the anchor against which later optimizations should be checked.

A serious implementation workflow should always begin here:

  1. derive the recurrence carefully
  2. implement it in a clear reference form
  3. verify shapes and numerical behavior
  4. add block structure around it
  5. only then consider optimized scan kernels or fused implementations

Without a trusted reference path, optimization becomes much harder to validate. With a trusted reference path, every later speedup can be tested against something simple and correct.

That is why this implementation deserves to appear in the tutorial. It is not a disposable baseline. It is the conceptual center of the PyTorch story.

[Placeholder Figure 8]

Reverse-time gradient structure of the scan

Once the forward recurrence is clear, the backward structure is not difficult to understand.

At each token, the gradient with respect to the current state receives contributions from two sources:

  • the local readout at the same token
  • the dependence of the next state on the current state

This means that the backward pass is itself a reverse-time recurrence. If we let $g_t$ denote the gradient arriving at state $x_t$, then the gradient flowing to $x_{t-1}$ is obtained by multiplying $g_t$ by the local transition factor $\bar{A}_t$, along with the additional contributions created by the readout and the token-dependent parameters.

This structure is conceptually important for two reasons.

First, it explains why the recurrence remains visible in the gradient computation. The backward pass is not independent of the forward dynamics. It mirrors them in reverse time.

Second, it explains why a reference implementation is so valuable. Once a custom backward or scan optimization is introduced, it can be checked against the clear recurrence structure already present in the reference path.

The full reverse-time derivation is best treated later, after the block structure and scan intuition are complete.
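The reverse-time structure can be checked numerically in the simplest setting. Assuming a scalar recurrence $x_t = a_t x_{t-1} + b_t$ with readout weights $c_t$ and loss $\sum_t c_t x_t$, autograd's gradient for each $b_t$ should equal the manually propagated reverse-time quantity $g_t$:

```python
import torch

torch.manual_seed(0)
T = 6
a = torch.rand(T) * 0.9 + 0.05          # per-token transition factors in (0, 1)
b = torch.randn(T, requires_grad=True)  # per-token write terms
c = torch.randn(T)                      # per-token readout weights

# Forward recurrence: x_t = a_t * x_{t-1} + b_t, with loss = sum_t c_t * x_t.
x = torch.zeros(())
loss = torch.zeros(())
for t in range(T):
    x = a[t] * x + b[t]
    loss = loss + c[t] * x
loss.backward()

# Manual reverse-time recurrence for g_t = dL/dx_t:
# g_T = c_T, then g_t = c_t + a_{t+1} * g_{t+1}, and dL/db_t = g_t.
g = torch.zeros(T)
g[T - 1] = c[T - 1]
for t in range(T - 2, -1, -1):
    g[t] = c[t] + a[t + 1] * g[t + 1]

torch.testing.assert_close(b.grad, g, atol=1e-5, rtol=1e-5)
```

The multiplication by the local transition factor in the backward recurrence is exactly the mirror of $\bar{A}_t$ in the forward pass.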


9. Why earlier SSMs could use convolution and Mamba cannot

At this stage, the selective recurrence is clear and the reference implementation is straightforward. The next question is computational.

Why did earlier state space sequence models admit especially efficient training procedures, while Mamba requires a different algorithmic story?

The answer is not that Mamba ceases to be recurrent. It remains recurrent. The answer is that Mamba ceases to be time-invariant.

That distinction is fundamental.

9.1 Fixed SSMs are time-invariant

Consider the discrete recurrence

\[x_t = \bar{A} x_{t-1} + \bar{B} u_t, \qquad y_t = C x_t + D u_t.\]

If the coefficients $\bar{A}$, $\bar{B}$, and $C$ are fixed across time, then the system is linear and time-invariant. The same update is applied at every step. In that setting, the output at each time can be written as a sum of contributions from earlier inputs, each transformed by repeated applications of the same transition.

To see the pattern, expand the recurrence for a few steps:

\[\begin{aligned} x_t &= \bar{A} x_{t-1} + \bar{B} u_t \\ &= \bar{A} \left(\bar{A} x_{t-2} + \bar{B} u_{t-1}\right) + \bar{B} u_t \\ &= \bar{A}^2 x_{t-2} + \bar{A} \bar{B} u_{t-1} + \bar{B} u_t. \end{aligned}\]

Continuing this expansion gives

\[x_t = \bar{A}^t x_0 + \sum_{k=0}^{t-1} \bar{A}^k \bar{B} u_{t-k}.\]

Substituting into the readout produces an expression in which the output is a weighted sum of past inputs, with weights determined by repeated powers of the same transition matrix. The influence of the past is described by a single stationary kernel.

Once that stationary kernel exists, the recurrence can be interpreted as a convolution.

This is one of the central computational advantages of earlier state space sequence models. If the system is time-invariant, then training over a full sequence can exploit convolutional structure rather than stepping through the recurrence token by token in Python.
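For a scalar system this equivalence is easy to verify directly. The sketch below, with arbitrarily chosen toy coefficients, computes the output once by stepping the recurrence and once by convolving the input with the stationary kernel $k_j = C \, \bar{A}^j \bar{B}$:

```python
import torch

torch.manual_seed(0)
L = 32
a, b, c = 0.9, 0.5, 1.3          # fixed scalar SSM coefficients (toy values)
u = torch.randn(L)

# Path 1: step the recurrence token by token.
x = 0.0
y_rec = torch.empty(L)
for t in range(L):
    x = a * x + b * u[t]
    y_rec[t] = c * x

# Path 2: one causal convolution with the stationary kernel k[j] = c * a^j * b.
k = c * (a ** torch.arange(L, dtype=torch.float32)) * b
y_conv = torch.empty(L)
for t in range(L):
    y_conv[t] = (k[: t + 1].flip(0) * u[: t + 1]).sum()

torch.testing.assert_close(y_rec, y_conv, atol=1e-4, rtol=1e-4)
```

The two paths agree because the weight on each past input depends only on the lag, which is exactly the time-invariance assumption.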

9.2 Why convolution disappears once the recurrence becomes selective

Mamba breaks exactly the assumption that made the convolutional view possible.

The selective recurrence is

\[x_t = \bar{A}_t \odot x_{t-1} + (\Delta_t \odot B_t) \odot u_t,\]

with

\[\bar{A}_t = \exp(\Delta_t \odot A).\]

The key objects now depend on $t$. The transition is no longer the same at every token. The write map is no longer the same at every token. The read map is no longer the same at every token.

That means there is no single stationary kernel that describes the entire sequence.

The contribution of an earlier token $u_s$ to a later state $x_t$ is no longer governed by repeated powers of one fixed matrix. Instead, it is governed by a product of token-dependent transitions:

\[x_t \sim \bar{A}_t \odot \bar{A}_{t-1} \odot \cdots \odot \bar{A}_{s+1} \odot \bigl((\Delta_s \odot B_s) \odot u_s\bigr).\]

That product depends on the exact path through the intervening tokens. Two positions at the same distance from $t$ need not contribute in the same way, because the sequence of transitions between them may differ.

This destroys the stationary convolution picture.

It is worth being explicit about what has been gained and what has been lost.

What has been gained is content-dependent memory control. The model can now preserve, erase, write, and read differently at each token.

What has been lost is time invariance. Once the coefficients depend on the token, the simple convolutional training strategy is no longer available.

This is the central computational tradeoff of Mamba.

9.3 The tradeoff is structural, not accidental

It would be a mistake to view this as a minor engineering inconvenience. It is a direct consequence of the modeling decision that gives Mamba its power.

If one wants a fixed recurrent memory with a stationary update rule, then convolutional structure is available. If one wants the memory policy itself to depend on the token, then the system becomes time-varying. A time-varying recurrence does not admit the same stationary convolution kernel.

This is not a weakness of the implementation. It is a structural fact about the model class.

Once this is understood, the right computational question is no longer “how do we turn Mamba back into a convolution?” That is the wrong objective. The right question is:

given that the recurrence is time-varying, how do we implement it efficiently enough to train at scale?

That question leads directly to scan.

9.4 A useful way to think about the loss of time invariance

The easiest intuition is to imagine two systems advancing across the same span of tokens.

In a fixed SSM, moving forward by ten tokens always means applying the same ten-step operator, regardless of where those ten tokens occur in the sequence.

In Mamba, moving forward by ten tokens means applying a product of ten token-specific operators. The effect of those ten steps depends on the actual content of the intervening region.

That is precisely what makes the model selective. It is also what makes the computation more subtle.

[Placeholder Figure 9]

More formal contrast between time-invariant and time-varying recurrence

A linear time-invariant discrete system has the form

\[x_t = \bar{A} x_{t-1} + \bar{B} u_t.\]

Its solution can be written explicitly as

\[x_t = \bar{A}^t x_0 + \sum_{s=1}^{t} \bar{A}^{t-s} \bar{B} u_s.\]

The dependence on the past is therefore determined entirely by the lag $t - s$. This lag-only dependence is what produces a convolution kernel.

A time-varying system, by contrast, has the form

\[x_t = \bar{A}_t x_{t-1} + \bar{B}_t u_t.\]

Its solution is

\[x_t = \left(\prod_{j=1}^{t} \bar{A}_j\right) x_0 + \sum_{s=1}^{t} \left( \prod_{j=s+1}^{t} \bar{A}_j \right) \bar{B}_s u_s.\]

Now the effect of $u_s$ depends on the entire sequence of transitions between $s$ and $t$, not merely on the lag. This is the algebraic reason that the stationary convolution view disappears once the recurrence becomes selective.
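The time-varying solution formula can be checked numerically. The sketch below steps a scalar version of the recurrence and compares it against the closed form with explicit products of token-dependent transitions:

```python
import torch

torch.manual_seed(0)
T = 10
a = torch.rand(T) + 0.1     # token-dependent transitions (index 0 plays the role of t = 1)
b = torch.randn(T)          # token-dependent write coefficients
u = torch.randn(T)          # inputs
x0 = torch.randn(())        # initial state

# Step the time-varying recurrence x_t = a_t * x_{t-1} + b_t * u_t.
x = x0.clone()
for t in range(T):
    x = a[t] * x + b[t] * u[t]

# Closed form: x_T = (prod_j a_j) x_0 + sum_s (prod_{j > s} a_j) b_s u_s.
# torch.prod of an empty slice is 1, handling the s = T term correctly.
closed = torch.prod(a) * x0
for s in range(T):
    closed = closed + torch.prod(a[s + 1:]) * b[s] * u[s]

torch.testing.assert_close(x, closed, atol=1e-5, rtol=1e-5)
```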


10. The scan idea at the right conceptual level

Once the recurrence becomes time-varying, a naive implementation must process the sequence sequentially in time. That is exactly what the reference PyTorch code does. The implementation is correct, but it does not yet explain how the model can be trained efficiently on modern hardware.

The key idea is that sequential dependence does not necessarily imply that the only valid algorithm is a serial Python loop. There is more structure in the recurrence than first appears.

This section explains that structure at the right level of abstraction. The goal is not to give a full parallel algorithms lecture. The goal is to make clear why scan enters the Mamba story at all.

10.1 Start with the scalar recurrence

Consider the scalar recurrence

\[x_t = a_t x_{t-1} + b_t.\]

Here both $a_t$ and $b_t$ may vary with time. This is the simplest possible version of a selective recurrence. It already contains the essential algebraic structure.

Suppose we apply the update for two consecutive steps:

\[x_1 = a_1 x_0 + b_1,\] \[x_2 = a_2 x_1 + b_2.\]

Substituting the first equation into the second gives

\[x_2 = a_2 a_1 x_0 + a_2 b_1 + b_2.\]

This is the first important observation. Two recurrence steps can be summarized by a single affine map of the form

\[x_{\mathrm{out}} = \alpha x_{\mathrm{in}} + \beta,\]

where

\[\alpha = a_2 a_1, \qquad \beta = a_2 b_1 + b_2.\]

A segment of the sequence therefore has a compact summary: it scales the incoming state by one coefficient and then adds one bias term.

10.2 Segment summaries compose associatively

Now consider two sequence segments.

Suppose the first segment maps

\[x \mapsto \alpha_1 x + \beta_1,\]

and the second maps

\[x \mapsto \alpha_2 x + \beta_2.\]

Applying the first and then the second gives

\[x \mapsto \alpha_2(\alpha_1 x + \beta_1) + \beta_2 = (\alpha_2 \alpha_1) x + (\alpha_2 \beta_1 + \beta_2).\]

So the composition law is

\[(\alpha_2, \beta_2) \circ (\alpha_1, \beta_1) = (\alpha_2 \alpha_1,\; \alpha_2 \beta_1 + \beta_2).\]

This composition is associative.

That fact is the foundation of scan.

The recurrence may still be sequential if read token by token, but contiguous chunks of tokens can be summarized by objects that compose associatively. Once associativity is available, parallel prefix-style algorithms become possible.

That is the conceptual bridge from recurrence to efficient sequence computation.
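Both the composition law and its associativity can be verified in a few lines. The `compose` helper below is a hypothetical name for the segment-summary product derived above:

```python
import torch

torch.manual_seed(0)

def compose(second, first):
    """Compose segment summaries: apply `first`, then `second`."""
    a2, b2 = second
    a1, b1 = first
    return (a2 * a1, a2 * b1 + b2)

# Three random per-token affine maps x -> a x + b.
f1, f2, f3 = [(torch.randn(()), torch.randn(())) for _ in range(3)]

# Associativity: (f3 o f2) o f1 == f3 o (f2 o f1).
left = compose(compose(f3, f2), f1)
right = compose(f3, compose(f2, f1))
torch.testing.assert_close(left[0], right[0])
torch.testing.assert_close(left[1], right[1])

# The composed summary reproduces the step-by-step recurrence.
x0 = torch.randn(())
x = x0
for a_t, b_t in (f1, f2, f3):
    x = a_t * x + b_t
alpha, beta = compose(f3, compose(f2, f1))
torch.testing.assert_close(x, alpha * x0 + beta)
```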

10.3 What this means for Mamba

The full Mamba recurrence is more structured than the scalar example, but the same idea survives.

At each token, the state update can be viewed as an affine map on the incoming state:

\[x_t = \bar{A}_t \odot x_{t-1} + w_t,\]

where

\[w_t = (\Delta_t \odot B_t) \odot u_t.\]

For each channel and state coordinate, the update has exactly the same “multiply old state, then add write term” form as the scalar recurrence. That means token segments can be summarized and composed.

This is why scan is the right algorithmic concept for Mamba. The model is not time-invariant, so the stationary convolution trick is gone. But the recurrence is still affine in the state, which means token blocks can be reduced to chunk summaries that combine associatively.

That is enough structure to support parallelization.

10.4 What scan does and does not do

It is important to be precise here.

Scan does not make the recurrence cease to exist. The model remains recurrent.

Scan also does not mean that the full sequence magically becomes one dense matrix multiplication. The dependence structure is still that of a recurrence.

What scan does provide is a way to organize the computation so that chunk summaries can be formed and combined efficiently, rather than stepping one token at a time through a high-overhead serial path.

That distinction matters because it keeps the conceptual picture honest. Mamba is not secretly an attention model. It is not secretly a convolutional model either. It is a selective recurrent model whose recurrence has enough algebraic structure to admit scan-oriented computation.

10.5 The right takeaway

For the purposes of this discussion, the correct takeaway is simple:

  • the reference implementation keeps the time loop explicit because that is the clearest expression of the recurrence
  • the recurrence can still be parallelized at a higher algorithmic level because token segments admit associative summaries
  • this is why scan becomes the central implementation idea once selection breaks time invariance

The low-level details of the kernel engineering can wait. The conceptual reason for scan is the important point.

[Placeholder Figure 10]

A more algebraic statement of the scan structure

For the scalar recurrence

\[x_t = a_t x_{t-1} + b_t,\]

each token defines an affine map

\[f_t(x) = a_t x + b_t.\]

A sequence prefix $1, \dots, t$ therefore defines the composed map

\[f_t \circ f_{t-1} \circ \cdots \circ f_1.\]

The key point is that affine maps compose associatively. If a segment $S_1$ has summary $(\alpha_1, \beta_1)$ and a segment $S_2$ has summary $(\alpha_2, \beta_2)$, then the combined segment has summary

\[(\alpha_2 \alpha_1,\; \alpha_2 \beta_1 + \beta_2).\]

This turns the problem of recurrent propagation into a prefix-composition problem over an associative binary operation.

The channelwise Mamba recurrence inherits this structure coordinatewise. That is the algebraic reason scan is possible.
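To illustrate why this matters computationally, the prefix composition can be evaluated in $O(\log L)$ vectorized combine rounds rather than $L$ sequential steps. The sketch below uses a Hillis–Steele style inclusive scan over $(\alpha, \beta)$ pairs; it is a conceptual demonstration of the associative structure, not the chunked hardware kernel used in practice, and `affine_prefix_scan` is a hypothetical name:

```python
import torch

def affine_prefix_scan(a, b):
    """Inclusive prefix scan over per-token affine maps x -> a_t * x + b_t.

    Returns (alpha, beta) with x_t = alpha_t * x_0 + beta_t for every prefix.
    Hillis-Steele style: O(log L) combine rounds, each vectorized over time."""
    alpha, beta = a.clone(), b.clone()
    L = a.shape[-1]
    step = 1
    while step < L:
        # Compose each running summary with the one ending `step` positions
        # earlier; positions t < step compose with the identity map (1, 0).
        prev_alpha = torch.ones_like(alpha)
        prev_beta = torch.zeros_like(beta)
        prev_alpha[..., step:] = alpha[..., :-step]
        prev_beta[..., step:] = beta[..., :-step]
        alpha, beta = alpha * prev_alpha, alpha * prev_beta + beta
        step *= 2
    return alpha, beta

torch.manual_seed(0)
L = 16
a = torch.rand(L) * 0.9 + 0.05
b = torch.randn(L)
x0 = torch.randn(())

alpha, beta = affine_prefix_scan(a, b)

# Every prefix of the parallel scan matches the sequential recurrence.
x = x0.clone()
for t in range(L):
    x = a[t] * x + b[t]
    torch.testing.assert_close(x, alpha[t] * x0 + beta[t], atol=1e-5, rtol=1e-5)
```

Because the combine operation is associative, the same summaries could equally be reduced chunk by chunk, which is the organization real implementations favor.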


11. Streaming inference is the natural mode of Mamba

The same recurrence that makes Mamba selective also makes streaming inference straightforward.

This is one of the most practically attractive features of the architecture. At autoregressive inference time, the model does not need to revisit the full history explicitly. It only needs to carry forward a compact state.

That is a direct consequence of the fact that the model is recurrent.

11.1 What must be stored at inference time

A Mamba block contains two forms of state during step-by-step generation.

The first is the long-range recurrent state of the SSM:

\[x_t \in \mathbb{R}^{B \times D_{\mathrm{inner}} \times N}.\]

This is the persistent memory that carries information forward from token to token.

The second is the short-range cache required by the causal depthwise convolution. If the convolution kernel size is $K$, then each channel needs to remember the previous $K - 1$ inputs. The convolution cache therefore has shape

\[\mathrm{conv\_state} \in \mathbb{R}^{B \times D_{\mathrm{inner}} \times (K - 1)}.\]

These two tensors are the only per-layer quantities that must be preserved during generation.

That is the essential memory story.

11.2 Full forward and step mode compute the same recurrence

During training, it is natural to process a whole sequence at once. The block receives $h \in \mathbb{R}^{B \times L \times D_{\mathrm{model}}}$, computes the selective parameters for all tokens, and runs the recurrence through the entire sequence.

During autoregressive generation, the model instead receives one token at a time. At each step, it:

  1. updates the convolution cache with the new token features
  2. generates one-step values of $\Delta_t$, $B_t$, and $C_t$
  3. updates the recurrent state
  4. produces one output token representation
  5. carries the updated state forward

The underlying mathematics is the same in both modes. The only difference is whether the sequence dimension is processed all at once or incrementally.

This is why a good implementation should support both:

  • a full-sequence forward pass for training
  • a one-token step function for generation

Those two interfaces are not separate models. They are two ways of executing the same recurrence.

11.3 One-token step in words

The one-token update for a single block can be described cleanly.

Start from an input token representation

\[h_t \in \mathbb{R}^{B \times D_{\mathrm{model}}}.\]

Normalize it and project it into the recurrent stream and gate stream. Update the convolution cache and compute the locally mixed recurrent features. From those features, generate $\Delta_t$, $B_t$, and $C_t$. Use them to update the SSM state through the selective recurrence. Read out the recurrent output, gate it with the $z$-stream, project back to the model dimension, and return both the output and the updated caches.

Nothing conceptually new happens in step mode. The point of step mode is simply that the recurrence already tells us how to continue computation from a compact summary of the past.

11.4 Why this is attractive for autoregressive generation

The main advantage is that the memory required per layer does not grow with the length of the generated sequence.

For a transformer with standard key-value caching, the cache grows with the number of processed tokens. For a recurrent model such as Mamba, the state size is fixed once the layer dimensions are fixed.

This does not mean Mamba is automatically faster in every deployment setting. Real systems performance depends on many factors. But it does mean that the architecture has a fundamentally different inference memory profile. The amount of persistent memory per layer is determined by

  • the SSM state size $D_{\mathrm{inner}} \times N$
  • the convolution cache size $D_{\mathrm{inner}} \times (K - 1)$

and not by the current sequence length.

That is an important architectural property.

11.5 The most important implementation test

Because full forward and step mode are meant to compute the same recurrence, there is one test that matters more than any other during implementation:

run the block on a sequence in full mode, then run the same sequence token by token in step mode, and compare the outputs

If the implementation is correct, the two results should agree up to numerical tolerance.

This test is valuable for two reasons.

First, it verifies the recurrence itself. Any discrepancy usually indicates a genuine bug in state handling, convolution cache updates, broadcasting, or parameter generation.

Second, it protects the interface between training and generation. A model that trains with one recurrence and generates with a subtly different recurrence is incorrect.

For a tutorial implementation, this equivalence test should be treated as part of the architecture, not as an optional debugging extra.

A simple version of that test looks like this:

import torch


def test_full_vs_step_equivalence() -> None:
    torch.manual_seed(0)

    B_size = 2
    L = 16
    d_model = 32
    d_inner = 48
    d_state = 8
    d_conv = 4

    block = MinimalMambaBlock(
        d_model=d_model,
        d_inner=d_inner,
        d_state=d_state,
        d_conv=d_conv,
    )
    block.eval()

    x = torch.randn(B_size, L, d_model)

    with torch.no_grad():
        y_full = block(x)                                              # B x L x D_model

        state = block.init_state(
            batch_size=B_size,
            device=x.device,
            dtype=x.dtype,
        )

        outputs = []
        for t in range(L):
            y_t, state = block.step(x[:, t, :], state)                 # B x D_model
            outputs.append(y_t)

        y_step = torch.stack(outputs, dim=1)                           # B x L x D_model

    torch.testing.assert_close(y_step, y_full, atol=1e-5, rtol=1e-5)

[Placeholder Figure 11]

Memory accounting for one layer during generation

Suppose a layer uses inner width $D_{\mathrm{inner}}$, state size $N$, and convolution kernel size $K$. Then the persistent state required for generation is

  • SSM state: $B \times D_{\mathrm{inner}} \times N$
  • convolution cache: $B \times D_{\mathrm{inner}} \times (K - 1)$

The total persistent memory per layer is therefore

\[B \times D_{\mathrm{inner}} \times \bigl(N + K - 1\bigr).\]

The important point is that this quantity does not involve the processed sequence length. It is fixed once the architecture and batch size are fixed.
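To make the accounting concrete, here is the arithmetic for one illustrative configuration. All sizes here are assumptions chosen for the example, not the settings of any particular released model:

```python
# Illustrative sizes; assumptions for the arithmetic only.
batch = 1
d_inner = 2048      # inner width D_inner
n_state = 16        # SSM state size N
d_conv = 4          # convolution kernel size K
n_layers = 48
bytes_per_elem = 2  # fp16

# Persistent elements per layer: B * D_inner * (N + K - 1).
per_layer_elems = batch * d_inner * (n_state + d_conv - 1)   # 2048 * 19 = 38912
total_mib = per_layer_elems * bytes_per_elem * n_layers / 2**20

assert per_layer_elems == 38_912
assert abs(total_mib - 3.5625) < 1e-9   # a few MiB total, independent of sequence length
```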


12. Putting the whole picture together

At this point, all the essential pieces are in place.

We began with the observation that recurrent sequence models are compression systems. A fixed recurrence compresses all past information through one stationary update rule. That works well when one global memory policy is enough. It becomes inadequate when different tokens should produce sharply different memory behaviors.

Mamba solves this by making the recurrence selective.

The architecture still maintains a finite state. It still processes the sequence causally. But the state update is no longer governed by one fixed rule. Instead, the model generates token-dependent quantities that determine:

  • how much old state should survive
  • what should be written into memory
  • how memory should be read into the output

That is the conceptual center of the model.

12.1 One pass through the block

It is useful to narrate a full Mamba block once from start to finish.

The block receives token representations at the model width. These are normalized and projected into two inner streams. One stream is destined for recurrent processing. The other becomes an output gate.

The recurrent stream is first passed through a short causal depthwise convolution. This injects local context before the memory policy is formed. The resulting features are then projected into the token-dependent quantities $\Delta$, $B$, and $C$. These quantities define, at each token, how quickly the internal dynamics should advance, how the current token should be written into state, and how the updated state should be read out.

The selective scan then applies the recurrence across the sequence. The resulting recurrent output is gated by the second stream, projected back to the model width, and added to the residual pathway.

That is the whole block. Once the selective recurrence is understood, everything else is scaffolding around it.

12.2 What Mamba adds to a plain SSM

The difference between a fixed SSM block and a Mamba block can now be stated precisely.

A fixed SSM gives the model a structured recurrent memory with stable and interpretable dynamics. The memory evolves according to a learned but stationary rule.

Mamba preserves that recurrent memory, but adds token-conditioned control over it. Specifically, it adds:

  • a token-dependent local clock through $\Delta_t$
  • a token-dependent write rule through $B_t$
  • a token-dependent read rule through $C_t$

These changes are enough to turn a fixed compression policy into an adaptive one.

That is why Mamba is more than “an SSM with a few gates.” The token-conditioned quantities are not cosmetic. They redefine how the memory system behaves at every step.

12.3 What remains fundamentally recurrent

At the same time, it is important not to overstate what has changed.

Mamba does not abandon recurrence. The model still propagates a finite state forward in time. The output at token $t$ still depends on the updated state at token $t$. The sequence dependence is still causal and directional.

That is why streaming inference is natural. It is also why the recurrence cannot simply be flattened into an ordinary stationary convolution once selection is introduced.

The model therefore lives in a very specific regime:

  • more adaptive than a fixed recurrent system
  • more stateful and streaming-friendly than full attention
  • computationally dependent on scan-style implementation ideas because selectivity destroys time invariance

This is the design space that makes Mamba interesting.

12.4 The main conceptual arc

The cleanest summary of the argument so far is the following.

An SSM provides a structured recurrent memory. Discretization turns that memory into a token-indexed recurrence. Mamba makes the recurrence selective by generating its update rule from token features. The resulting model remains causal and finite-state, which makes step-by-step generation natural. But because the update rule now varies with the token, the model loses the stationary convolution structure of earlier SSMs and must instead rely on scan-style computation.

That is the architecture in one paragraph.

12.5 Where the implementation goes next

With the conceptual picture complete, the remaining implementation tasks are now well-defined:

  1. build the block in PyTorch around the reference scan
  2. implement the one-token step path with explicit cached state
  3. verify that full forward and step mode agree numerically
  4. only then move on to kernel-level optimization and hardware-aware implementation details

That is the correct order of operations. The reference recurrence comes first. The block comes next. Scan optimization comes later.

The reason is simple. Once the selective recurrence is clear, the rest of Mamba becomes a matter of packaging, parameterization, and efficient execution.

[Placeholder Figure 12]

A compact comparison of the architectural tradeoff

The design tradeoff can be summarized in three lines.

A fixed recurrent SSM offers:

  • structured memory
  • stable timescales
  • time invariance and therefore convolutional structure

Mamba adds:

  • token-dependent control over forgetting, writing, and reading
  • stronger content sensitivity
  • natural streaming inference

Mamba gives up:

  • the stationary convolution view of the sequence computation

That final point is not a flaw. It is the price of the additional expressivity.


13. A practical implementation roadmap

The conceptual structure of Mamba is now in place. The remaining question is how to build it in a disciplined way.

The correct implementation order is not to begin with kernels, fusion, or benchmarking. It is to begin with the recurrence itself and add complexity only when each layer of the construction is already verified. This matters because Mamba is a model whose core idea is simple but whose implementation can become opaque very quickly if one starts from an optimized code path.

A clean implementation roadmap has seven stages.

13.1 Start with a scalar selective recurrence

The first stage is to implement the recurrence in the simplest possible setting. Use a single scalar state and a short synthetic sequence. At this stage, the goal is not performance and not even generality. The goal is to make the recurrence behavior visible.

In a scalar recurrence, one can inspect every quantity directly:

  • the previous state
  • the token-dependent decay
  • the write term
  • the resulting output

This is the right place to verify the conceptual role of $\Delta_t$, the effect of stronger or weaker forgetting, and the interaction between preservation and writing.

If the mechanism is not clear in the scalar case, it will not become clearer when buried inside a full block.
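A scalar version of this inspection might look as follows. The dynamics value $A = -1$, the unit write coefficient, and the specific deltas are arbitrary choices made for the demonstration:

```python
import math

A = -1.0  # fixed negative dynamics; delta_t sets the effective timescale

def step(x_prev, u_t, delta_t, b_t=1.0):
    """One scalar selective update: x_t = exp(delta_t * A) * x_prev + delta_t * b_t * u_t."""
    a_bar = math.exp(delta_t * A)           # token-dependent survival factor
    return a_bar * x_prev + delta_t * b_t * u_t

x = 1.0  # some stored memory

# A tiny delta barely advances the dynamics: the old state survives almost intact.
x_keep = step(x, u_t=5.0, delta_t=0.01)

# A large delta advances the dynamics far: the surviving fraction of the old
# state drops below 5%, and the new input dominates the updated state.
x_forget = step(x, u_t=5.0, delta_t=3.0)

assert abs(x_keep - x) < 0.1
assert math.exp(3.0 * A) < 0.05
```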

13.2 Implement a batched one-step selective update

The second stage is to generalize from the scalar toy setting to the actual tensor geometry of the model.

Write a function that performs a single selective update with shapes

  • input token stream at one time step
    $u_t \in \mathbb{R}^{B \times D}$

  • token-dependent step size
    $\Delta_t \in \mathbb{R}^{B \times D}$

  • shared dynamics
    $A \in \mathbb{R}^{D \times N}$

  • token-dependent write map
    $B_t \in \mathbb{R}^{B \times D \times N}$

  • token-dependent read map
    $C_t \in \mathbb{R}^{B \times D \times N}$

  • recurrent state
    $x_{t-1} \in \mathbb{R}^{B \times D \times N}$

This function should do exactly one thing: compute the next state and the current output. It should not know anything about sequence loops, convolutions, residual connections, or model-level packaging.

That separation is useful because it isolates the algebra of the selective recurrence from the machinery that surrounds it.
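A sketch of such a function, assuming the discretization convention $\bar{A}_t = \exp(\Delta_t \odot A)$ used throughout, with shape assertions matching the list above (`selective_step` is a hypothetical name):

```python
import torch

def selective_step(u_t, delta_t, A, B_t, C_t, x_prev):
    """One selective update. Shapes follow the list above:
    u_t, delta_t: B x D;  A: D x N;  B_t, C_t, x_prev: B x D x N.
    Returns (x_t, y_t) with shapes (B x D x N, B x D)."""
    Bsz, D = u_t.shape
    N = A.shape[-1]
    assert delta_t.shape == (Bsz, D)
    assert A.shape == (D, N)
    assert B_t.shape == C_t.shape == x_prev.shape == (Bsz, D, N)

    delta = delta_t.unsqueeze(-1)                 # B x D x 1
    u = u_t.unsqueeze(-1)                         # B x D x 1
    A_bar = torch.exp(delta * A)                  # B x D x N
    x_t = A_bar * x_prev + (delta * B_t) * u      # preserve, then write
    y_t = (C_t * x_t).sum(dim=-1)                 # contract over the state axis N
    return x_t, y_t
```

Note that the function knows nothing about sequences: it is the pure algebra of one update.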

13.3 Implement the full reference scan

Once the one-step update is correct, the next stage is the full sequence recurrence.

This is the selective_scan_reference function described earlier. It should take tensors of shape

  • $u \in \mathbb{R}^{B \times L \times D}$
  • $\Delta \in \mathbb{R}^{B \times L \times D}$
  • $B \in \mathbb{R}^{B \times L \times D \times N}$
  • $C \in \mathbb{R}^{B \times L \times D \times N}$

and loop explicitly over sequence length while vectorizing over batch, channel, and state.

At this stage, it is worth being strict about shape assertions and intermediate variable names. A reference implementation is most valuable when it is easy to audit line by line.
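One possible sketch of such a function, following the shapes listed above and keeping only the time loop explicit. The optional skip term is included because the direct path $D_{\mathrm{skip}} \odot u_t$ is part of the block's design:

```python
import torch

def selective_scan_reference(u, delta, A, B, C, D_skip=None):
    """Reference selective scan with an explicit loop over time.
    u, delta: B x L x D;  A: D x N;  B, C: B x L x D x N;  D_skip: D (optional).
    Returns y: B x L x D."""
    Bsz, L, D = u.shape
    N = A.shape[-1]
    assert delta.shape == (Bsz, L, D)
    assert B.shape == C.shape == (Bsz, L, D, N)

    x = u.new_zeros(Bsz, D, N)                    # recurrent state
    ys = []
    for t in range(L):
        d_t = delta[:, t].unsqueeze(-1)           # B x D x 1
        u_t = u[:, t].unsqueeze(-1)               # B x D x 1
        A_bar = torch.exp(d_t * A)                # B x D x N
        x = A_bar * x + (d_t * B[:, t]) * u_t     # selective update
        y_t = (C[:, t] * x).sum(dim=-1)           # contract over state axis N
        if D_skip is not None:
            y_t = y_t + D_skip * u[:, t]          # direct skip path
        ys.append(y_t)
    return torch.stack(ys, dim=1)                 # B x L x D
```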

13.4 Add a minimal causal depthwise convolution

The selective recurrence alone is not yet a Mamba block. The next stage is to add the short-range local mixer that produces better token features for parameter generation.

A minimal causal depthwise convolution is sufficient. The implementation should preserve three properties:

  • channelwise operation
  • strict causality
  • explicit cache behavior for later step mode

This is not a place to pursue cleverness. The goal is to make clear what local context is being supplied to the recurrent stream before the memory policy is formed.
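A minimal version, assuming a token-major $B \times L \times D$ stream and one length-$K$ kernel per channel (`causal_depthwise_conv` is a hypothetical name), can be written with left padding and a grouped convolution:

```python
import torch
import torch.nn.functional as F

def causal_depthwise_conv(u, weight, bias=None):
    """Causal depthwise convolution over a token-major stream.
    u: B x L x D;  weight: D x K (one kernel per channel).
    Pads K - 1 zeros on the left so output t sees only inputs <= t."""
    D, K = weight.shape
    x = u.transpose(1, 2)                         # B x D x L, the conv1d layout
    x = F.pad(x, (K - 1, 0))                      # left-only padding: strict causality
    y = F.conv1d(x, weight.unsqueeze(1), bias=bias, groups=D)  # channelwise
    return y.transpose(1, 2)                      # back to B x L x D
```

In step mode, the same kernel is applied to the cached last $K - 1$ inputs plus the new token, which is why the cache shape $B \times D \times (K - 1)$ appears later.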

13.5 Wrap the pieces into a self-contained Mamba block

With the scan and the convolution in place, one can build the first full block.

That block should include:

  • normalization
  • input projection into recurrent and gate streams
  • causal depthwise convolution on the recurrent stream
  • projection from recurrent features into $\Delta$, $B$, and $C$
  • stable parameterization of $A$ and $\Delta$
  • selective scan over the full sequence
  • output gating
  • projection back to model width
  • residual connection

The right goal at this stage is not a feature-complete library module. It is a compact block whose internal logic is obvious from the code.

13.6 Implement step mode with explicit cached state

Only after the full-sequence block is correct should one add token-by-token inference.

A one-token step path should carry exactly the state that the recurrence requires:

  • SSM state with shape $B \times D_{\mathrm{inner}} \times N$
  • convolution cache with shape $B \times D_{\mathrm{inner}} \times (K - 1)$

The step function should mirror the full block as closely as possible. It should not introduce a second, independently designed computation. The most reliable implementation is one in which the step path is visibly the same recurrence executed on a single token at a time.

13.7 Add numerical tests before any optimization work

Before thinking about scan kernels or hardware-aware execution, it is essential to lock down correctness.

The minimum useful tests are these.

First, a full-forward versus step-mode equivalence test. Run a block on the same input sequence in both modes and check that the outputs match to numerical tolerance.

Second, a causality test. Perturb future tokens and verify that earlier outputs do not change.

Third, a small gradient sanity check. On a toy configuration, verify that the loss produces finite gradients and that basic finite-difference checks agree with autodiff in the expected regime.

Fourth, a shape discipline test. This is less glamorous, but just as important. Explicit shape checks catch a large fraction of early implementation bugs.

This order of operations is not bureaucratic. It is the shortest path to a trustworthy implementation.
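The causality test in particular is short enough to show in full. The sketch below runs a tiny inline scan as the test subject (a stand-in for the real block; in the full model $\Delta$, $B_t$, and $C_t$ would themselves depend on the input) and checks that perturbing a future token leaves earlier outputs untouched:

```python
import torch

def tiny_scan(u, delta, A, B, C):
    """Minimal selective scan, used here only as the test subject."""
    Bsz, L, D = u.shape
    x = u.new_zeros(Bsz, D, A.shape[-1])
    ys = []
    for t in range(L):
        d = delta[:, t].unsqueeze(-1)
        x = torch.exp(d * A) * x + (d * B[:, t]) * u[:, t].unsqueeze(-1)
        ys.append((C[:, t] * x).sum(-1))
    return torch.stack(ys, dim=1)

torch.manual_seed(0)
Bsz, L, D, N = 2, 8, 3, 4
u = torch.randn(Bsz, L, D)
delta = torch.rand(Bsz, L, D)
A = -torch.rand(D, N)
Bm = torch.randn(Bsz, L, D, N)
Cm = torch.randn(Bsz, L, D, N)

y = tiny_scan(u, delta, A, Bm, Cm)

# Causality: perturbing a future token must not change earlier outputs,
# while the outputs at and after the perturbed position should move.
t_perturb = 5
u2 = u.clone()
u2[:, t_perturb] += 10.0
y2 = tiny_scan(u2, delta, A, Bm, Cm)
torch.testing.assert_close(y[:, :t_perturb], y2[:, :t_perturb])
assert not torch.allclose(y[:, t_perturb:], y2[:, t_perturb:])
```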

13.8 The principle behind the roadmap

The roadmap above is built around one rule:

every optimization should be downstream of a clear and trusted reference path

This matters especially for Mamba because the model contains both conceptual and systems complexity. The conceptual center is the selective recurrence. The systems complexity comes later, once one tries to make that recurrence fast on real hardware. If the reference path is weak, it becomes very difficult to tell whether an optimized implementation is fast and correct, or merely fast-looking.

For that reason, the right workflow is always:

  1. derive the recurrence
  2. implement the clear reference version
  3. wrap it into a block
  4. verify full and step execution
  5. only then optimize

That is the practical meaning of “from first principles” in code.


14. Conclusion

Mamba becomes much easier to understand once the architecture is reduced to its essential question:

How should a recurrent model decide what to keep, what to forget, and what to expose at each token?

A fixed recurrence answers that question with one stationary rule. That is often too rigid for information-dense sequences such as language. Some tokens should leave durable traces. Others should have only a local effect. A useful memory system must be able to distinguish between them.

Mamba makes that distinction by turning the recurrence itself into a token-dependent object.

The model still carries a finite state forward in time. It still updates that state causally. But the update is no longer governed by one fixed transition. Instead, the current token determines the effective forgetting rate, the write rule into state, and the read rule out of state. This is the key conceptual move of the architecture.

Seen from this perspective, the rest of the model falls into place.

The state space formulation provides the language of structured memory. The token-dependent local clock $\Delta_t$ gives the model direct control over memory timescales. The write map $B_t$ determines what enters memory. The read map $C_t$ determines how memory affects the output. The skip term preserves a direct local path. The surrounding block structure exists to generate these quantities from token features in a stable and trainable way.
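These roles can be summarized in one pair of equations, using the standard zero-order-hold discretization of a selective SSM (the notation here may differ slightly from earlier sections):

$$
h_t = \bar{A}_t\, h_{t-1} + \bar{B}_t\, x_t,
\qquad
y_t = C_t\, h_t + D\, x_t,
\qquad
\bar{A}_t = \exp(\Delta_t A), \quad \bar{B}_t \approx \Delta_t B_t.
$$

Here $\Delta_t$, $B_t$, and $C_t$ are generated from the current token, while $A$ and the skip weight $D$ are shared, time-invariant parameters.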

This is why Mamba is best understood neither as “just another RNN” nor as “attention without attention.” It is more precise to say that Mamba is a selective state space model: a recurrent architecture with structured dynamics and token-conditioned memory control.

That modeling choice has a direct computational consequence. Once the update rule depends on the token, the model is no longer time-invariant. The stationary convolution view of earlier SSMs disappears. In its place, one must exploit the affine structure of the recurrence through scan-oriented computation. That algorithmic story is important, but it comes after the recurrence itself. One should first understand what is being computed before worrying about how it is accelerated.

For that reason, this post has centered on the reference recurrence and the block built around it. A clean PyTorch implementation is not merely a pedagogical convenience. It is the correct foundation for every later question about step-mode inference, numerical stability, or kernel optimization.

The full conceptual arc can now be stated compactly.

A structured state space model gives a principled recurrent memory. Mamba makes that memory selective by generating its update rule from token features. The resulting architecture remains causal and finite-state, which makes streaming inference natural. At the same time, selection breaks time invariance, which is why scan becomes the right computational lens for efficient training.