
Flash-KMeans: Fast and Memory-Efficient Exact K-Means

Abstract

k-means has historically been positioned primarily as an offline processing primitive, typically used for dataset organization or embedding preprocessing rather than as a first-class component in online systems. In this work, we revisit this classical algorithm through the lens of modern AI system design and enable k-means as an online primitive. We point out that existing GPU implementations of k-means remain fundamentally bottlenecked by low-level system constraints rather than theoretical algorithmic complexity. Specifically, the assignment stage suffers from a severe IO bottleneck due to the explicit materialization of the massive N × K distance matrix in High Bandwidth Memory (HBM). Simultaneously, the centroid update stage is heavily penalized by hardware-level atomic write contention caused by irregular, scatter-style token aggregations. To bridge this performance gap, we propose flash-kmeans, an IO-aware and contention-free k-means implementation for modern GPU workloads. Flash-kmeans introduces two core kernel-level innovations: (1) FlashAssign, which fuses distance computation with an online argmin to completely bypass intermediate memory materialization; and (2) sort-inverse update, which explicitly constructs an inverse mapping to transform high-contention atomic scatters into high-bandwidth, segment-level localized reductions. Furthermore, we integrate algorithm-system co-designs, including chunked-stream overlap and cache-aware compile heuristics, to ensure practical deployability. Extensive evaluations on NVIDIA H200 GPUs demonstrate that flash-kmeans achieves up to 17.9× end-to-end speedup over the best baselines, while outperforming industry-standard libraries such as cuML and FAISS by 33× and over 200×, respectively.

One-sentence Summary

Researchers from UC Berkeley, MIT, and UT Austin propose flash-kmeans, an IO-aware GPU implementation that eliminates distance matrix materialization and atomic contention via FlashAssign and sort-inverse update, delivering up to 17.9x speedup for scalable online clustering in modern AI workloads.

Key Contributions

  • Existing GPU implementations of k-means are hindered by severe IO bottlenecks from materializing massive distance matrices and hardware-level atomic contention during centroid updates, preventing their use as efficient online primitives.
  • Flash-KMeans addresses these issues with two core kernel innovations: FlashAssign, which fuses distance computation with online argmin to bypass intermediate memory storage, and sort-inverse update, which transforms high-contention atomic scatters into localized reductions.
  • Evaluations on NVIDIA H200 GPUs show up to a 17.9x end-to-end speedup over best baselines and over 200x improvement compared to FAISS, while enabling seamless out-of-core execution on up to one billion points.

Introduction

K-means clustering is evolving from an offline data processing tool into a critical online primitive for modern AI systems, including vector quantization, sparse routing in large language models, and generative video pipelines. Despite this shift, existing GPU implementations fail to meet latency requirements because they remain bottlenecked by hardware constraints rather than algorithmic complexity. Prior approaches waste memory bandwidth by explicitly materializing massive distance matrices and suffer hardware-level serialization from atomic write contention during centroid updates. To address these issues, the authors introduce Flash-KMeans, an exact and IO-aware implementation that fuses distance computation with an online argmin to bypass intermediate memory storage and replaces irregular atomic scatters with a sort-inverse update strategy for efficient aggregation. This system-level redesign eliminates the key bottlenecks, delivering up to 17.9× end-to-end speedup over baselines while enabling scalable execution on datasets exceeding one billion points.

Method

The authors introduce flash-kmeans, a highly optimized implementation designed to overcome the severe memory and synchronization bottlenecks inherent in standard GPU-based k-means clustering. The methodology focuses on restructuring the execution dataflow to eliminate IO overheads and resolve write-side contention without altering the underlying mathematical objective.

FlashAssign: Materialization-Free Assignment

To address the memory wall caused by materializing the massive distance matrix D ∈ ℝ^{N×K}, the authors propose FlashAssign. This module fuses the distance computation and row-wise reduction into a single streaming procedure. Instead of writing the full distance matrix to High Bandwidth Memory (HBM) and reading it back, FlashAssign maintains running states for the minimum distance and corresponding centroid index directly in registers.

The process utilizes an online argmin update. For each data point, the kernel scans centroids in tiles. It computes local distances on-chip, identifies the local minimum within the tile, and compares it with the running minimum to update the global assignment. This approach ensures that the N × K distance matrix is never explicitly constructed in memory.

As illustrated in the framework diagram above, the algorithm loops over centroid tiles. For each point X_i, it computes distances against a centroid block C_j, finds the local minimum, and updates the global minimum index. By employing two-dimensional tiling and asynchronous prefetching, the kernel hides memory latency while reducing the IO complexity from O(NK) to O(Nd + Kd).
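The tiled online-argmin loop above can be sketched in NumPy. This is an illustrative model of the dataflow only (function and parameter names such as `flash_assign` and `tile` are ours, not the paper's): the real FlashAssign is a fused CUDA kernel that keeps the running minimum in registers, whereas here each tile's scratch distances are an N × tile array that stands in for on-chip storage.

```python
import numpy as np

def flash_assign(X, C, tile=128):
    """Streaming assignment: scan centroids in tiles and keep a running
    (min distance, argmin index) per point, so the full N x K distance
    matrix is never materialized."""
    N = X.shape[0]
    best_dist = np.full(N, np.inf)
    best_idx = np.zeros(N, dtype=np.int64)
    x_sq = (X * X).sum(axis=1)                 # ||x||^2, reused across tiles
    for start in range(0, C.shape[0], tile):
        Cj = C[start:start + tile]             # one centroid tile
        # squared distances against this tile only (N x tile scratch)
        d = x_sq[:, None] - 2.0 * X @ Cj.T + (Cj * Cj).sum(axis=1)[None, :]
        local_idx = d.argmin(axis=1)           # per-point minimum in the tile
        local_min = d[np.arange(N), local_idx]
        upd = local_min < best_dist            # online argmin update
        best_dist[upd] = local_min[upd]
        best_idx[upd] = start + local_idx[upd]
    return best_idx, best_dist
```

Because the update is an exact running minimum, the result matches a full-matrix argmin while the peak scratch footprint depends on the tile size rather than K.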

Sort-Inverse Update: Low-Contention Aggregation

In the centroid update stage, standard implementations suffer from severe atomic write contention because multiple threads frequently attempt to update the same centroid simultaneously using scatter-style atomic additions. To resolve this, the authors propose the sort-inverse update strategy.

The core idea is to transform the token-to-cluster update into a cluster-to-token gather operation. The system first applies an argsort operation to the assignment vector a to obtain a permutation index. This reorders the tokens such that identical cluster IDs are grouped into contiguous segments.

The figure below contrasts the standard scatter-style update with the proposed sort-inverse approach. In the standard method (a), tokens are scattered irregularly, causing conflicts across multiple blocks. In the sort-inverse method (b), tokens are sorted by cluster ID, creating contiguous segments. This allows each Cooperative Thread Array (CTA) to process a chunk of the sorted sequence, gathering features from the original matrix and accumulating partial sums entirely in fast on-chip memory. Global atomic operations are only issued at segment boundaries.

This reorganization drastically reduces the number of atomic operations from O(Nd) to O((K + ⌈N/B_N⌉)d). As shown in the execution timeline (c), this eliminates the frequent stalls caused by atomic lock contention, enabling contention-free memory writes and significantly accelerating the reduction phase.
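A minimal NumPy sketch of the sort-inverse idea follows (names like `sort_inverse_update` are ours; on the GPU each segment reduction runs in a CTA's on-chip memory, with global atomics only at segment boundaries, whereas here segments are reduced with a plain loop):

```python
import numpy as np

def sort_inverse_update(X, assign, K):
    """Centroid update via sort-inverse: argsort the assignment vector so
    each cluster's tokens form a contiguous segment, then do one localized
    reduction per segment (a gather) instead of one scattered atomic add
    per token."""
    order = np.argsort(assign, kind="stable")   # inverse mapping
    sorted_assign = assign[order]
    gathered = X[order]                         # tokens now contiguous per cluster
    sums = np.zeros((K, X.shape[1]))
    counts = np.zeros(K, dtype=np.int64)
    # segment boundaries: one reduction per cluster segment
    bounds = np.flatnonzero(np.diff(sorted_assign)) + 1
    starts = np.concatenate(([0], bounds))
    ends = np.concatenate((bounds, [len(assign)]))
    for s, e in zip(starts, ends):
        c = sorted_assign[s]
        sums[c] = gathered[s:e].sum(axis=0)     # localized segment reduction
        counts[c] = e - s
    safe = np.maximum(counts, 1)[:, None]
    centroids = np.where(counts[:, None] > 0, sums / safe, 0.0)
    return centroids, counts
```

The output is identical to a scatter-style accumulation; only the write pattern changes, which is exactly what removes the contention.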

Algorithm-System Co-design

To ensure deployability in real systems, flash-kmeans incorporates several system-level optimizations. For large-scale data that exceeds GPU memory, the authors implement a chunked stream overlap design. This partitions data into chunks and uses CUDA streams to coordinate asynchronous host-to-device transfers with computation, following a double-buffer streaming pattern. Additionally, a cache-aware compile heuristic is employed to select high-quality kernel configurations based on hardware characteristics and problem shape, minimizing the time-to-first-run overhead typically associated with exhaustive tuning.
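The chunked-stream design can be illustrated with a host-side sketch (our own simplification: on a GPU the host-to-device copy of chunk i+1 would run on a second CUDA stream while chunk i computes, following the double-buffer pattern; here chunks are simply processed in order to show that peak working memory stays bounded by the chunk size):

```python
import numpy as np

def chunked_kmeans_iter(X, C, chunk=4096):
    """One out-of-core k-means iteration over fixed-size chunks, so peak
    working memory is O(chunk * d + K * d) regardless of N. Accumulators
    for cluster sums and counts persist across chunks."""
    K, d = C.shape
    sums = np.zeros((K, d))
    counts = np.zeros(K, dtype=np.int64)
    for start in range(0, X.shape[0], chunk):
        Xc = X[start:start + chunk]             # stand-in for an H2D transfer
        dists = ((Xc[:, None, :] - C[None, :, :]) ** 2).sum(-1)
        a = dists.argmin(axis=1)                # assignment for this chunk
        np.add.at(sums, a, Xc)                  # accumulate partial sums
        counts += np.bincount(a, minlength=K)
    nonzero = counts > 0
    C_new = C.copy()                            # empty clusters keep old centroid
    C_new[nonzero] = sums[nonzero] / counts[nonzero, None]
    return C_new
```

Because the accumulation is associative, the chunked iteration produces exactly the same centroids as a full-batch pass, which is what makes the out-of-core execution exact rather than approximate.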

Experiment

  • Efficiency evaluations demonstrate that flash-kmeans consistently outperforms optimized baselines across diverse workloads, achieving up to 17.9× speedup in compute-intensive scenarios and 15.3× in highly batched settings while preventing out-of-memory failures in memory-intensive regimes.
  • Kernel-level analysis confirms that custom FlashAssign and Sort-Inverse Update modules effectively eliminate distance matrix materialization and atomic contention bottlenecks, delivering up to 21.2× and 6.3× speedups respectively.
  • Large-scale out-of-core experiments validate that the system successfully processes datasets up to one billion points by bounding peak memory usage, resulting in 6.3× to 10.5× faster iteration times compared to the most robust existing baseline.
  • Algorithm-system co-design tests show that a cache-aware compile heuristic reduces configuration search time by up to 175× compared to exhaustive tuning while maintaining near-optimal runtime performance with negligible degradation.
