Skip to content

Reduce heap usage in hierarchical k-means #132391

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

import java.io.IOException;
import java.util.Arrays;
import java.util.Objects;

/**
* An implementation of the hierarchical k-means algorithm that better partitions data than naive k-means
Expand Down Expand Up @@ -148,30 +149,32 @@ static FloatVectorValues createClusterSlice(int clusterSize, int cluster, FloatV
}

void updateAssignmentsWithRecursiveSplit(KMeansIntermediate current, int cluster, KMeansIntermediate subPartitions) {
if (subPartitions.centroids().length == 0) {
return; // nothing to do, sub-partitions is empty
}
int orgCentroidsSize = current.centroids().length;
int newCentroidsSize = current.centroids().length + subPartitions.centroids().length - 1;

// update based on the outcomes from the split clusters recursion
if (subPartitions.centroids().length > 1) {
float[][] newCentroids = new float[newCentroidsSize][dimension];
System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length);
float[][] newCentroids = new float[newCentroidsSize][];
System.arraycopy(current.centroids(), 0, newCentroids, 0, current.centroids().length);

// replace the original cluster
int origCentroidOrd = 0;
newCentroids[cluster] = subPartitions.centroids()[0];
// replace the original cluster
int origCentroidOrd = 0;
newCentroids[cluster] = subPartitions.centroids()[0];

// append the remainder
System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1);
// append the remainder
System.arraycopy(subPartitions.centroids(), 1, newCentroids, current.centroids().length, subPartitions.centroids().length - 1);
assert Arrays.stream(newCentroids).allMatch(Objects::nonNull);

current.setCentroids(newCentroids);
current.setCentroids(newCentroids);

for (int i = 0; i < subPartitions.assignments().length; i++) {
// this is a new centroid that was added, and so we'll need to remap it
if (subPartitions.assignments()[i] != origCentroidOrd) {
int parentOrd = subPartitions.ordToDoc(i);
assert current.assignments()[parentOrd] == cluster;
current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1;
}
for (int i = 0; i < subPartitions.assignments().length; i++) {
// this is a new centroid that was added, and so we'll need to remap it
if (subPartitions.assignments()[i] != origCentroidOrd) {
int parentOrd = subPartitions.ordToDoc(i);
assert current.assignments()[parentOrd] == cluster;
current.assignments()[parentOrd] = subPartitions.assignments()[i] + orgCentroidsSize - 1;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
package org.elasticsearch.index.codec.vectors.cluster;

import org.apache.lucene.index.FloatVectorValues;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.hnsw.IntToIntFunction;
import org.elasticsearch.index.codec.vectors.SampleReader;
Expand Down Expand Up @@ -70,17 +71,14 @@ private static boolean stepLloyd(
FloatVectorValues vectors,
IntToIntFunction translateOrd,
float[][] centroids,
float[][] nextCentroids,
FixedBitSet centroidChanged,
int[] centroidCounts,
int[] assignments,
NeighborHood[] neighborhoods
) throws IOException {
boolean changed = false;
int dim = vectors.dimension();
int[] centroidCounts = new int[centroids.length];

for (float[] nextCentroid : nextCentroids) {
Arrays.fill(nextCentroid, 0.0f);
}
centroidChanged.clear();
final float[] distances = new float[4];
for (int idx = 0; idx < vectors.size(); idx++) {
float[] vector = vectors.vectorValue(idx);
Expand All @@ -93,20 +91,39 @@ private static boolean stepLloyd(
bestCentroidOffset = getBestCentroid(centroids, vector, distances);
}
if (assignment != bestCentroidOffset) {
if (assignment != -1) {
centroidChanged.set(assignment);
}
centroidChanged.set(bestCentroidOffset);
assignments[vectorOrd] = bestCentroidOffset;
changed = true;
}
centroidCounts[bestCentroidOffset]++;
for (int d = 0; d < dim; d++) {
nextCentroids[bestCentroidOffset][d] += vector[d];
}
}
if (changed) {
Arrays.fill(centroidCounts, 0);
for (int idx = 0; idx < vectors.size(); idx++) {
final int assignment = assignments[translateOrd.apply(idx)];
if (centroidChanged.get(assignment)) {
float[] centroid = centroids[assignment];
if (centroidCounts[assignment]++ == 0) {
Arrays.fill(centroid, 0.0f);
}
float[] vector = vectors.vectorValue(idx);
for (int d = 0; d < dim; d++) {
centroid[d] += vector[d];
}
}
}

for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
if (centroidCounts[clusterIdx] > 0) {
float countF = (float) centroidCounts[clusterIdx];
for (int d = 0; d < dim; d++) {
centroids[clusterIdx][d] = nextCentroids[clusterIdx][d] / countF;
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
if (centroidChanged.get(clusterIdx)) {
float count = (float) centroidCounts[clusterIdx];
if (count > 0) {
float[] centroid = centroids[clusterIdx];
for (int d = 0; d < dim; d++) {
centroid[d] /= count;
}
}
}
}
}
Expand Down Expand Up @@ -420,17 +437,18 @@ private void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansInterme
}

assert assignments.length == n;
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
FixedBitSet centroidChanged = new FixedBitSet(centroids.length);
int[] centroidCounts = new int[centroids.length];
for (int i = 0; i < maxIterations; i++) {
// This is potentially sampled, so we need to translate ordinals
if (stepLloyd(sampledVectors, translateOrd, centroids, nextCentroids, assignments, neighborhoods) == false) {
if (stepLloyd(sampledVectors, translateOrd, centroids, centroidChanged, centroidCounts, assignments, neighborhoods) == false) {
break;
}
}
// If we were sampled, do a once over the full set of vectors to finalize the centroids
if (sampleSize < n || maxIterations == 0) {
// No ordinal translation needed here, we are using the full set of vectors
stepLloyd(vectors, i -> i, centroids, nextCentroids, assignments, neighborhoods);
stepLloyd(vectors, i -> i, centroids, centroidChanged, centroidCounts, assignments, neighborhoods);
}
}

Expand Down