Flow Matching in 3 Minutes
Busy person’s intro to Flow Matching
In this post, I will try to build an intuitive understanding of Flow Matching, a framework used to train many state-of-the-art generative image models.
We start with two probability distributions, $p_{\text{source}}$ and $p_{\text{target}}$. Our goal is to transform a point sampled from $p_{\text{source}}$ into a point that could plausibly have been sampled from $p_{\text{target}}$.
2D Example
Suppose $p_{\text{source}}$ and $p_{\text{target}}$ are isotropic Gaussian distributions centred at $(0, -5)$ and $(5, 0)$, respectively. We sample one point from each distribution and, to keep the arithmetic simple, suppose the samples happen to land exactly on the means: $(0, -5)$ (source) and $(5, 0)$ (target).
How do we move from the source to the target?
The simplest approach is a straight-line trajectory, which minimizes the total distance we need to travel. The vector difference between target and source, $(5,0) - (0,-5) = (5,5)$, gives both the direction and the magnitude of the required movement.
If we want to move in a single step, we move directly by $(5,5)$. But suppose we want to take five steps instead. We can decompose the movement as $(5,5) = 5 \cdot (1,1)$ and travel by $(1,1)$ five times. For five steps, the trajectory looks like:
$$ (0, -5) \rightarrow (1, -4) \rightarrow (2, -3) \rightarrow (3, -2) \rightarrow (4, -1) \rightarrow (5, 0) $$
Notice that the movement at each step, $(1,1)$, stays the same.
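The five-step decomposition can be reproduced with a few lines of plain Python. This is a minimal sketch; `straight_line_trajectory` is a hypothetical helper name, not part of any library.

```python
def straight_line_trajectory(source, target, num_steps):
    """Hypothetical helper: points visited when moving from source to target in equal steps."""
    # total displacement, e.g. (5, 0) - (0, -5) = (5, 5)
    direction = tuple(b - a for a, b in zip(source, target))
    # per-step displacement, e.g. (5, 5) / 5 = (1, 1)
    step = tuple(d / num_steps for d in direction)
    points = [tuple(source)]
    for _ in range(num_steps):
        last = points[-1]
        points.append(tuple(p + s for p, s in zip(last, step)))
    return points


print(straight_line_trajectory((0, -5), (5, 0), 5))
# [(0, -5), (1.0, -4.0), (2.0, -3.0), (3.0, -2.0), (4.0, -1.0), (5.0, 0.0)]
```

Calling it with `num_steps=1` gives the single-step move `[(0, -5), (5.0, 0.0)]`.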
Understanding Intermediate Steps
Now consider an intermediate step, such as step 3 of 5. First, we must determine our current location. Since the motion is along a straight line and the source and target points are known, we can compute the position by interpolating between them. To simplify, we normalize time to the range $[0, 1]$, so step 3 of 5 corresponds to $t = 3/5$:
$$x_t = x_{3/5} = \left(1-\tfrac{3}{5}\right) (0, -5) + \tfrac{3}{5} (5, 0) = (0, -2) + (3, 0) = (3, -2)$$
Note that $(3,-2)$ is not a point sampled from either $p_{\text{source}}$ or $p_{\text{target}}$; it lies somewhere between the two distributions. The timestep $t=3/5$ tells us where we are in the transition. Critically, regardless of the timestep, the direction of movement stays the same: $(1,1)$ per step, or equivalently $(5,5)$ over the whole unit time interval.
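As a quick check of the arithmetic, the sketch below evaluates the interpolation with exact fractions. `interpolate` is a hypothetical helper, and `Fraction` is used only so the result is free of floating-point rounding.

```python
from fractions import Fraction


def interpolate(source, target, t):
    """Hypothetical helper: position at normalized time t on the line from source to target."""
    return tuple((1 - t) * a + t * b for a, b in zip(source, target))


x0, x1 = (0, -5), (5, 0)
x_t = interpolate(x0, x1, Fraction(3, 5))  # step 3 of 5
print(x_t == (3, -2))  # True

# The direction does not depend on t: it is always target - source.
direction = tuple(b - a for a, b in zip(x0, x1))  # (5, 5)
```

At $t = 0$ and $t = 1$ the same function returns the source and target points, respectively.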
This is basically all there is to flow matching! I’ve provided the code below which matches the intuition we built above.
```python
import torch

# a single training step (assumes `model`, `optimizer`, `source` and `target` are already defined)
t = torch.rand(1)  # sample a single (normalized) timestep in [0, 1)
intermediate_t = (1 - t) * source + t * target  # current position at timestep t
direction = target - source  # this is the direction we always have to move
# we give the model the timestep so it knows where between source and target we are
prediction = model(intermediate_t, t)
loss = ((direction - prediction) ** 2).mean()  # standard regression
optimizer.zero_grad()
loss.backward()
optimizer.step()  # update the model parameters
```
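The training step above uses one timestep for the whole batch. A common variant draws one timestep per example; the NumPy sketch below shows how the shapes broadcast in that case. It assumes unit-variance Gaussians for the 2D example and replaces the neural network with a bare loss function (`flow_matching_loss` is a hypothetical name), so it illustrates the regression targets rather than training a real model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed batch for the 2D example: unit-variance Gaussians centred at (0, -5) and (5, 0).
batch_size = 8
source = rng.normal(loc=(0.0, -5.0), scale=1.0, size=(batch_size, 2))
target = rng.normal(loc=(5.0, 0.0), scale=1.0, size=(batch_size, 2))

# One timestep per example, shaped (batch, 1) so it broadcasts over both coordinates.
t = rng.uniform(0.0, 1.0, size=(batch_size, 1))

intermediate_t = (1 - t) * source + t * target  # each pair's position at its own t
direction = target - source                     # regression target for each pair


def flow_matching_loss(prediction, direction):
    """Hypothetical helper: mean squared error between predicted and true directions."""
    return np.mean((direction - prediction) ** 2)


print(flow_matching_loss(direction, direction))  # 0.0 for a perfect prediction
```

In a real training loop, `prediction` would come from the model evaluated at `(intermediate_t, t)`.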
Formally, what we called the “direction” is the rate of change of $x_t$ with respect to time $t$: $\frac{dx_t}{dt}$. Differentiating the interpolation $x_t = (1-t)x_0 + t x_1$ gives
$$\frac{dx_t}{dt} = -x_0 + x_1 = x_1 - x_0$$
Thus, $\frac{dx_t}{dt}$ is exactly the direction of movement from $x_0$ (source) to $x_1$ (target). In the 2D example it equals $(5,5)$; the per-step movement $(1,1)$ is this velocity multiplied by the step size $1/5$.
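The derivative can also be checked numerically. This sketch, with hypothetical helpers `x_t` and `finite_difference_velocity`, estimates $\frac{dx_t}{dt}$ by central differences for the 2D example and recovers $(5,5)$ at every $t$.

```python
def x_t(t, x0=(0.0, -5.0), x1=(5.0, 0.0)):
    """Straight-line interpolation x_t = (1 - t) * x0 + t * x1 for the 2D example."""
    return tuple((1 - t) * a + t * b for a, b in zip(x0, x1))


def finite_difference_velocity(t, h=1e-6):
    """Central-difference estimate of dx_t/dt."""
    plus, minus = x_t(t + h), x_t(t - h)
    return tuple((p - m) / (2 * h) for p, m in zip(plus, minus))


for t in (0.1, 0.5, 0.9):
    print(t, finite_difference_velocity(t))  # approximately (5.0, 5.0) at every t
```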
Sampling
Typically, $p_{\text{source}}$ is a distribution we can easily sample from (e.g., a Gaussian), while we cannot do the same for the more complex $p_{\text{target}}$; we usually only have data samples from $p_{\text{target}}$, which we use for training. To generate a new sample from $p_{\text{target}}$, we start from a point sampled from $p_{\text{source}}$ and iteratively move in the predicted direction. For a trajectory with `NUM_STEPS` steps, we scale the prediction by `1 / NUM_STEPS` at each iteration and pass the model the normalized timestep `i / NUM_STEPS`, matching the range it was trained on:
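```python
import torch

x = torch.randn(1)  # sample from the source distribution (shape should match your data)
for i in range(NUM_STEPS):
    t = torch.tensor([i / NUM_STEPS])  # normalized timestep in [0, 1), as in training
    prediction = model(x, t)  # predict the direction to move
    x = x + (1 / NUM_STEPS) * prediction  # scale and move
x  # this should now resemble a sample from p_target
```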
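To see the sampling loop in isolation, the sketch below replaces the trained network with a hypothetical `oracle_model` that returns the exact direction $(5,5)$ for the single source/target pair of the 2D example. A real model's prediction would generally vary with both the position and $t$; this stand-in only makes the Euler updates easy to follow.

```python
NUM_STEPS = 5


def oracle_model(x, t):
    """Hypothetical stand-in for a trained model in the single-pair 2D example:
    it always predicts the true direction target - source = (5, 5)."""
    return (5.0, 5.0)


x = (0.0, -5.0)  # the source point from the 2D example
trajectory = [x]
for i in range(NUM_STEPS):
    t = i / NUM_STEPS  # normalized timestep in [0, 1), as during training
    prediction = oracle_model(x, t)
    x = tuple(xi + (1 / NUM_STEPS) * vi for xi, vi in zip(x, prediction))  # Euler step
    trajectory.append(x)

print(trajectory)
# [(0.0, -5.0), (1.0, -4.0), (2.0, -3.0), (3.0, -2.0), (4.0, -1.0), (5.0, 0.0)]
```

With five steps, the loop reproduces the trajectory from the 2D example and ends at the target $(5, 0)$.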
High-Dimensional Case
This intuition extends directly to higher dimensions. For example, in image generation, the source could be a high-dimensional Gaussian distribution over arrays of shape $H \times W \times 3$, and the target the distribution of real images. The flow-matching process transforms samples from the source distribution (e.g., random noise) into samples resembling real images.
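Nothing in the update rules depends on the number of dimensions; the same arithmetic applies elementwise to image-shaped arrays. Below is a minimal NumPy sketch, assuming a tiny hypothetical $4 \times 6$ RGB "image" filled with random values as a stand-in for real data, and a hypothetical `perfect_model` that knows the true direction for this one pair.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W = 4, 6  # hypothetical tiny image size; real images are much larger

source = rng.standard_normal((H, W, 3))    # Gaussian noise of shape H x W x 3
target = rng.uniform(0.0, 1.0, (H, W, 3))  # placeholder standing in for a real image

t = 0.3
x_t = (1 - t) * source + t * target  # same interpolation, applied to every pixel and channel
direction = target - source          # regression target, also of shape H x W x 3


def perfect_model(x, t):
    """Hypothetical stand-in for a trained model that knows this pair's direction."""
    return direction


num_steps = 10
x = source.copy()
for i in range(num_steps):
    x = x + (1 / num_steps) * perfect_model(x, i / num_steps)  # Euler step

print(np.allclose(x, target))  # True
```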