Token Merging for Fast Stable Diffusion: An Analysis

Sahil Sakhuja
Jun 4, 2023

I have been pursuing my Master's in Data Science at Harvard University (Extension School), which gives me the opportunity to work on various interesting projects. Most recently, in my class on Advanced Deep Learning, my project required me to evaluate the approach proposed by the authors of the paper Token Merging for Fast Stable Diffusion.

Image generation models like Stable Diffusion require significant time & cost even for inference tasks. While there have been developments in speeding up these models (like Flash Attention and xFormers), there is still scope to make them faster, since the attention computation scales with the square of the number of tokens / pixels in an image.

In their paper, the authors propose Token Merging (ToMe): merging redundant tokens (or pixels) in an image to reduce the amount of work done by the Transformer layers. This requires no re-training of the model and stacks on top of the other approaches mentioned above.

I have, in the course of my project, showcased & demonstrated the following core outcomes:

  1. Token Merging leads to significant gains in processing efficiency
  2. This gain, however, comes at a cost in image quality, which the authors' testing techniques do not reveal and which requires more rigorous testing with real-life use cases

Introduction to ToMe (Token Merging) for Stable Diffusion

The Core Idea


Token Merging (ToMe) merges redundant tokens (or pixels) in an image to reduce the amount of work to be done by the Transformer layers.

Merging & Unmerging tokens before each layer in the Diffusion Block
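To make the quadratic scaling concrete, here is a rough back-of-the-envelope sketch. It assumes a 512x512 image, which Stable Diffusion's VAE downsamples by 8 into a 64x64 latent, so the highest-resolution Transformer blocks see 64 * 64 = 4096 tokens. It counts only the query-key scores of one self-attention map; the helper names are hypothetical.

```python
def attention_pairs(num_tokens: int) -> int:
    """Number of query-key scores in one self-attention map (N x N)."""
    return num_tokens * num_tokens


def merged_token_count(num_tokens: int, ratio: float) -> int:
    """Tokens left after merging away `ratio` of them (hypothetical helper)."""
    return num_tokens - int(num_tokens * ratio)


if __name__ == "__main__":
    latent_side = 512 // 8               # 512px image -> 64x64 latent
    tokens = latent_side * latent_side   # 4096 tokens at the highest resolution
    for ratio in (0.0, 0.25, 0.5, 0.75):
        kept = merged_token_count(tokens, ratio)
        reduction = attention_pairs(tokens) / attention_pairs(kept)
        print(f"ratio={ratio:.2f}: {kept} tokens, {reduction:.1f}x fewer attention scores")
```

Halving the token count cuts the attention score matrix by 4x, but the end-to-end speedup is smaller, since merging adds its own overhead and the rest of the network is not reduced in the same way.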

Advantages of ToMe:

  • No re-training required
  • Stacks benefits in speed & efficiency on top of other methods like xFormers

Implementation: Patching

ToMe has been implemented by the authors in a very innovative way. Rather than creating a completely different model on top, ToMe is implemented as a “patch” that can be applied to an existing Stable Diffusion model: it directly replaces the Transformer Block modules (known as “BasicTransformerBlock”) with its own “ToMeBlock”.

The simple call is:

import tomesd
tomesd.apply_patch(model, ratio=0.5)

Where,

  • tomesd — The ToMe library (installation covered in Hardware & installation section below)
  • apply_patch — The relevant method from the library
  • model — The object for the top level Stable Diffusion model
  • ratio — The fraction of tokens to be merged during the Token Merging process (e.g. 0.5 merges 50% of the tokens)

This ToMeBlock module essentially replaces the “forward” method of the original BasicTransformerBlock module, merging tokens before the actual matrix calculations and unmerging them afterwards.

ToMe Implementation
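The patching idea can be sketched in plain Python. As far as I can tell from the tomesd source, the library reassigns each block's `__class__` to a generated `ToMeBlock` subclass; everything below is a simplified, hypothetical stand-in for that idea. The stand-in `BasicTransformerBlock`, `make_tome_block`, `apply_patch_sketch`, `remove_patch_sketch` and the toy pairwise-average merge are illustrative assumptions, not the library's code.

```python
class BasicTransformerBlock:
    """Hypothetical stand-in for Stable Diffusion's block (the real one is a torch nn.Module)."""

    def forward(self, x):
        return [v * 2 for v in x]  # placeholder for the attention + MLP computation


def make_tome_block(block_class):
    """Return a subclass of block_class whose forward wraps the original with merge/unmerge."""

    class ToMeBlock(block_class):
        _original_class = block_class  # remembered so the patch can be undone

        def forward(self, x):
            merged = self._merge(x)            # fewer tokens enter the layers
            out = super().forward(merged)      # the original computation, unchanged
            return self._unmerge(out, len(x))  # restore the original token count

        def _merge(self, x):
            # Toy merge: average adjacent pairs (assumes an even token count).
            # ToMe itself merges by similarity, not by position.
            return [(x[i] + x[i + 1]) / 2 for i in range(0, len(x), 2)]

        def _unmerge(self, y, n):
            # Toy unmerge: copy each merged output back to both original positions.
            return [v for v in y for _ in range(2)][:n]

    return ToMeBlock


def apply_patch_sketch(modules):
    """Swap every unpatched BasicTransformerBlock's class for a ToMe subclass, in place."""
    for module in modules:
        if isinstance(module, BasicTransformerBlock) and not hasattr(module, "_original_class"):
            module.__class__ = make_tome_block(module.__class__)


def remove_patch_sketch(modules):
    """Undo apply_patch_sketch by restoring each block's original class."""
    for module in modules:
        original = getattr(module, "_original_class", None)
        if original is not None:
            module.__class__ = original
```

Because only the class of the existing objects changes, the objects themselves (and, in the real model, their weights) stay in place, which is why no re-training or reloading is needed.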

How it works: Merging / Unmerging

ToMe divides tokens into 2 sets: src (source) and dst (destination). The split is made randomly with a 2x2 stride (configurable) over the image, i.e. in each 2x2 box (4 tokens), one random token is selected as a dst token and the remaining 3 become src tokens.

The authors tested several ways of choosing the destination tokens (random selection, row-wise selection, etc.) and found random selection within a 2x2 stride to work best.
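The 2x2 partition can be sketched over a grid of token indices. This is an illustrative pure-Python version, not tomesd's tensorized implementation: the function name is hypothetical, and while `sx`/`sy` echo the stride arguments of `tomesd.apply_patch`, this function is not part of the library. It assumes the grid height and width are multiples of the stride.

```python
import random


def split_src_dst(height, width, sx=2, sy=2, rng=None):
    """Pick one random dst token per sy x sx cell; all other tokens become src.

    Tokens are indexed row-major (index = row * width + col).
    Assumes height % sy == 0 and width % sx == 0 (a simplification).
    """
    if rng is None:
        rng = random.Random()
    dst, src = [], []
    for cell_row in range(0, height, sy):
        for cell_col in range(0, width, sx):
            # All token indices inside this sy x sx cell.
            cell = [(cell_row + dy) * width + (cell_col + dx)
                    for dy in range(sy) for dx in range(sx)]
            choice = rng.randrange(len(cell))  # random dst position within the cell
            dst.append(cell[choice])
            src.extend(idx for i, idx in enumerate(cell) if i != choice)
    return src, dst
```

On a 4x4 grid this yields 4 dst tokens (one per 2x2 cell) and 12 src tokens, so at most 75% of the tokens can ever be merged with this stride.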

During merging, ToMe compares the src tokens to the dst tokens and merges the most similar src tokens into their best-matching dst tokens until r% of the tokens have been merged away. Similarity is measured with cosine similarity; the authors tested multiple methods and found cosine similarity to work best. This r% is the “ratio” provided to the ToMe module during patching. The image to the right shows a sample of what this operation looks like, with the blue boxes denoting the dst tokens and the orange boxes denoting the src tokens.
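The matching step can be sketched as a small bipartite merge. This is a hypothetical pure-Python illustration: each src token finds its most similar dst token by cosine similarity, and the `r` src tokens with the highest scores are averaged into their matches (a plain unweighted average). Here `r` is a token count, which I assume corresponds to `ratio` times the number of tokens.

```python
import math


def cosine(a, b):
    """Cosine similarity between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def bipartite_merge(src, dst, r):
    """Merge the r src tokens that best match some dst token into that dst token.

    src, dst: lists of equal-length float vectors.
    Returns (kept_src, new_dst); merged src tokens are averaged into their dst.
    """
    # For every src token, find its single most similar dst token.
    matches = []
    for i, s in enumerate(src):
        j = max(range(len(dst)), key=lambda k: cosine(s, dst[k]))
        matches.append((cosine(s, dst[j]), i, j))
    # Only the r highest-scoring src tokens get merged.
    matches.sort(reverse=True)
    merged = matches[:r]
    merged_src = {i for _, i, _ in merged}

    sums = [list(d) for d in dst]
    counts = [1] * len(dst)
    for _, i, j in merged:
        sums[j] = [a + b for a, b in zip(sums[j], src[i])]
        counts[j] += 1
    new_dst = [[v / c for v in vec] for vec, c in zip(sums, counts)]
    kept_src = [s for i, s in enumerate(src) if i not in merged_src]
    return kept_src, new_dst
```

The total token count drops from `len(src) + len(dst)` to that minus `r`, which is what shrinks the attention computation.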

The forward method installed during patching computes the src / dst tokens, stores their indices and generates the “merge” / “unmerge” functions.

These functions, as can be seen below, are applied to the inputs & outputs respectively of the attention and MLP layers.

The “compute_merge” function called below separates the indices into src and dst tokens and returns a set of functions. The functions prefixed with “m_” are the merge functions, applied to the normalized tensor. After that, the relevant processing of the tensor is performed (attention or linear layer). Finally, the values are returned after applying the functions prefixed with “u_”, which are the unmerge functions.

Forward Method of the ToMe Block
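The merge / unmerge pattern can be illustrated with closures. This is a simplified, hypothetical sketch, not tomesd's `compute_merge`: `compute_merge_sketch` and `tome_forward` are my own names, tokens are scalars, and the matching is passed in rather than computed. It shows the residual step described above, where the layer only sees the merged tokens and its output is unmerged before being added back.

```python
def compute_merge_sketch(n_tokens, match):
    """Build (merge, unmerge) closures from a precomputed matching.

    match maps each merged-away token index to the index of a kept token it merges into.
    """
    kept = [i for i in range(n_tokens) if i not in match]
    position = {idx: p for p, idx in enumerate(kept)}  # token index -> slot in merged sequence

    def merge(x):
        # Average each kept token with the tokens merged into it.
        total = [x[i] for i in kept]
        count = [1] * len(kept)
        for src, dst in match.items():
            total[position[dst]] += x[src]
            count[position[dst]] += 1
        return [t / c for t, c in zip(total, count)]

    def unmerge(y):
        # Scatter outputs back; merged tokens copy the output of their dst token.
        out = [0.0] * n_tokens
        for idx in kept:
            out[idx] = y[position[idx]]
        for src, dst in match.items():
            out[src] = y[position[dst]]
        return out

    return merge, unmerge


def tome_forward(x, layer, norm, merge, unmerge):
    """Residual step: x + unmerge(layer(merge(norm(x))))."""
    update = unmerge(layer(merge(norm(x))))
    return [a + b for a, b in zip(x, update)]
```

Note that merged tokens end up with identical layer outputs after unmerging; that shared output is the approximation behind both the speedup and the change in the final image.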

ToMe Effectiveness: Paper Results

The testing of an approach like ToMe requires validating 2 things:

  1. Success Metrics: That the implementation actually LEADS TO improvements in the required outcomes (i.e. time for generating an image & memory used for generating an image)
  2. Control Metrics: That the implementation DOES NOT LEAD TO a significant change or deviation from the actual results that Stable Diffusion would have normally generated

The authors have followed a testing strategy that can be described as follows:

Test Dataset: Images generated for the 1,000 ImageNet-1k class labels

Success Metrics:

  • Generation Time: seconds / img
  • Generation Memory: GB / img
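Measuring the time metric can be sketched with a small helper. This is a hypothetical harness, not the authors' benchmark code: `generate` stands for any zero-argument callable that produces one image, and warm-up calls are excluded so one-off costs do not skew the mean.

```python
import statistics
import time


def seconds_per_image(generate, n_images, warmup=1):
    """Average wall-clock seconds per call of `generate` (a hypothetical zero-arg callable)."""
    for _ in range(warmup):
        generate()  # one-off costs (loading, caching) happen here, untimed
    timings = []
    for _ in range(n_images):
        start = time.perf_counter()
        generate()
        timings.append(time.perf_counter() - start)
    return statistics.mean(timings)
```

On a GPU, CUDA kernels run asynchronously, so a real `generate` should call `torch.cuda.synchronize()` before returning. Peak memory per image can be read similarly with `torch.cuda.reset_peak_memory_stats()` before a run and `torch.cuda.max_memory_allocated()` after it.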

Control Metric (to verify that the generated images are similar to the actual ImageNet database):

  • Frechet Inception Distance (FID): The FID compares the distribution of generated images with the distribution of a set of real images (“ground truth”). Rather than comparing images pixel by pixel (as, for example, the L2 norm does), the FID compares the mean and covariance of the activations of the final pooling layer of Inception v3.
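For reference, the standard FID definition fits a Gaussian to the Inception v3 activations of each image set (real r, generated g), with means μ and covariances Σ. Because it compares two distributions, it needs many images per set to be meaningful:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\left( \Sigma_r + \Sigma_g - 2 \left( \Sigma_r \Sigma_g \right)^{1/2} \right)
```

Lower is better; identical distributions give an FID of 0.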

Based on this testing, the authors were able to show ~2X faster image generation times, with very little impact on FID:

However, I see this testing strategy as insufficient to really understand the impact that Token Merging has on image generation quality, due to the following shortcomings:

  1. Test Outputs: The authors used the ImageNet class labels as the prompts while generating the images. These labels are basic and rudimentary at best: each refers to a single object per image rather than calling for many details and objects in one image.
  2. Using FID as a control metric: Since the prompt complexity, and thereby the output complexity, is extremely low, FID changes only slightly even as the merging ratio increases. Moreover, this metric does not capture the actual “deviation” of an image generated by a ToMe-patched model from the baseline image that a vanilla Stable Diffusion implementation generates for the same prompt.

ToMe Evaluation & Testing

To address the shortcomings of the testing approach, I have expanded on the testing as follows:

  1. I have used appropriate seeds to force the Stable Diffusion model to give deterministic results
  2. Instead of testing using ImageNet 1-k classes, I have used a database of complex prompts to generate images
  3. Since my implementation of Stable Diffusion gives deterministic results, i.e. the same prompt always leads to the same image, I then used MSE (mean squared error) to measure the difference between images generated at different levels of Token Merging and the baseline image generated without merging. I chose not to use FID, since FID measures the difference between distributions of images, whereas my intention was to run multiple test cases, each comparing only 2 images: one generated with Token Merging and the other generated without.
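A minimal MSE between a baseline image and a ToMe image might look like the following. It assumes both images are already decoded to the same resolution and the same pixel scale (e.g. 0 to 255); the post doesn't state which scale was used, and the absolute MSE values depend on it. The helper names are hypothetical, and in practice one would use NumPy on float arrays, e.g. `np.mean((a - b) ** 2)`.

```python
def flatten(image):
    """Yield pixel values from a nested list such as H x W x C."""
    for item in image:
        if isinstance(item, (list, tuple)):
            yield from flatten(item)
        else:
            yield item


def image_mse(baseline, candidate):
    """Mean squared error between two images with the same number of pixel values."""
    a, b = list(flatten(baseline)), list(flatten(candidate))
    if len(a) != len(b):
        raise ValueError("images must have the same shape")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)
```

An MSE of 0 means the ToMe image is pixel-identical to the baseline; any merging-induced change raises it.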

Stable Diffusion Model Setup for Inference: Seeding for Deterministic Outputs

The first step for me was to be able to use the Stable Diffusion Model to:

  1. Construct a Model Object
  2. Be able to use it to generate a deterministic image

I have customized the code provided by the Stable Diffusion team (under scripts/txt2img.py) to achieve this. I had to make the following changes to make this work effectively:

  1. Removal of non-required parts of the code
  2. Addition of ToMe remove patch and apply patch calls
  3. Addition of seeds at different places to ensure deterministic results. This was a critical component: to compare a vanilla Stable Diffusion model with a ToMe-patched one, the vanilla model had to always give a deterministic output. As can be seen below, by introducing seeds at 3 different places, I was able to get deterministic results. The most important place to introduce the seed was where the “start_code” (the initial latent noise) is created.
Deterministic Results from Stable Diffusion Model
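Why seeding the start code matters can be shown with a small stand-in. In the original script, `start_code` is initial latent noise drawn with `torch.randn`; here Python's `random.Random` substitutes for it so the example runs without PyTorch, and `make_start_code` is a hypothetical name. In PyTorch the equivalent would be `torch.manual_seed(seed)` or passing a seeded `torch.Generator` to `torch.randn`.

```python
import random


def make_start_code(seed, shape=(4, 8, 8)):
    """Hypothetical stand-in for torch.randn(shape): Gaussian noise from a seeded RNG.

    shape mirrors a latent layout (channels, height, width), scaled down for brevity.
    """
    rng = random.Random(seed)  # local generator, unaffected by other random calls
    c, h, w = shape
    return [[[rng.gauss(0.0, 1.0) for _ in range(w)] for _ in range(h)] for _ in range(c)]
```

With the same seed the starting noise is identical, so a deterministic sampler produces the same image; with a different seed the run starts from different noise.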

Test Results: Impact on Generated Image — MSE

Since the seeds make the Stable Diffusion process deterministic, a patched model would reproduce the baseline image exactly if merging had no effect. However, because ToMe merges tokens (using a dynamic, random selection of dst tokens), the images do change, and the difference from the baseline (as judged by MSE) is higher for higher values of r.

Mean & Std Dev of MSE increase with the percentage of tokens being merged

This can also be validated visually: the left-most images in the grid below were generated without Token Merging, and moving right, the ratio of tokens being merged increases (10%, 25%, 50% & 75%). We can clearly see that the images change significantly once merging exceeds 50%.

Initially, my intuition was that longer prompts would lead to more detailed images, and hence Token Merging would lead to larger differences from the expected images. However, the data shows the opposite: shorter prompts, and the smaller number of nouns that comes with them, lead to higher MSEs.

Impact of Prompt Length & Number of Nouns in Prompt on MSE @ 50% of Tokens Merged
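One way to quantify this relationship is a correlation between prompt length and MSE at a fixed merge ratio. This is a hypothetical sketch: the post doesn't say how prompt length was measured (words, characters or CLIP tokens), so word count is an assumption, and counting nouns would additionally need a part-of-speech tagger, which is omitted here. A negative coefficient would match the observation that shorter prompts give higher MSEs.

```python
import math


def prompt_length(prompt):
    """Prompt length in whitespace-separated words (one possible definition)."""
    return len(prompt.split())


def pearson(xs, ys):
    """Pearson correlation between equal-length sequences (undefined if either is constant)."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    return cov / (sx * sy)
```

Usage would be `pearson([prompt_length(p) for p in prompts], mses)`, where `prompts` and `mses` are the per-prompt inputs and results at one merge ratio.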

In summary, while the Token Merging technique proposed by the authors does lead to significant efficiency improvements in running models like Stable Diffusion, it changes the model's inference and can make the model more prone to incorrect or unexpected outcomes. Hence, the technique should be used with this risk in mind.

Links & Special Mentions

Original Paper:

Youtube Videos:

Git Repo:

Special thanks to Prof. Zoran B. Djordjevic and the teaching staff (shout out to Rahul Joglekar) for their guidance on this project.
