Skip to content

Speed up hierarchical k-means by computing distances in bulk #132384

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the "Elastic License
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
* Public License v 1"; you may not use this file except in compliance with, at
* your election, the "Elastic License 2.0", the "GNU Affero General Public
* License v3.0 only", or the "Server Side Public License, v 1".
*/
package org.elasticsearch.benchmark.vector;

import org.apache.lucene.util.VectorUtil;
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
import org.elasticsearch.common.logging.LogConfigurator;
import org.elasticsearch.simdvec.ESVectorUtil;
import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.BenchmarkMode;
import org.openjdk.jmh.annotations.Fork;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Mode;
import org.openjdk.jmh.annotations.OutputTimeUnit;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;
import org.openjdk.jmh.infra.Blackhole;

import java.io.IOException;
import java.util.Random;
import java.util.concurrent.TimeUnit;

/**
 * JMH throughput benchmark that scores every query against every vector, once with the
 * one-pair-per-call distance functions and once with the bulk functions that take four
 * vectors per call.
 */
@BenchmarkMode(Mode.Throughput)
@OutputTimeUnit(TimeUnit.MILLISECONDS)
@State(Scope.Benchmark)
// the first iteration is not representative, so warm up over several iterations
@Warmup(iterations = 4, time = 1)
// measured iterations are kept short; more forks are a better use of time than longer runs
@Measurement(iterations = 5, time = 1)
// single fork unless a benchmark method overrides it
@Fork(value = 1)
public class DistanceBulkBenchmark {

    static {
        // native access requires logging to be initialized
        LogConfigurator.configureESLogging();
    }

    @Param({ "384", "782", "1024" })
    int dims;

    int length;

    // kept a multiple of 4 so the bulk benchmarks always consume complete groups of four vectors
    int numVectors = 4 * 100;
    int numQueries = 10;

    float[][] vectors;
    float[][] queries;
    // reusable output buffer for the bulk calls
    float[] distances = new float[4];

    /**
     * Allocates the vectors and the queries and fills them from a fixed-seed random generator,
     * vectors first and queries second.
     */
    @Setup
    public void setup() throws IOException {
        final Random rng = new Random(123);

        length = OptimizedScalarQuantizer.discretize(dims, 64) / 8;

        vectors = new float[numVectors][dims];
        fillWithRandomFloats(vectors, rng);

        queries = new float[numQueries][dims];
        fillWithRandomFloats(queries, rng);
    }

    /** Overwrites every component of every row, in row order, with the next float drawn from {@code rng}. */
    private static void fillWithRandomFloats(float[][] rows, Random rng) {
        for (float[] row : rows) {
            for (int d = 0; d < row.length; d++) {
                row[d] = rng.nextFloat();
            }
        }
    }

    /** Baseline: one {@code VectorUtil.squareDistance} call per (query, vector) pair. */
    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public void squareDistance(Blackhole bh) {
        for (int q = 0; q < numQueries; q++) {
            final float[] currentQuery = queries[q];
            for (int v = 0; v < numVectors; v++) {
                bh.consume(VectorUtil.squareDistance(currentQuery, vectors[v]));
            }
        }
    }

    /** Baseline: one {@code ESVectorUtil.soarDistance} call per pair, passing the vector itself as the residual. */
    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public void soarDistance(Blackhole bh) {
        for (int q = 0; q < numQueries; q++) {
            final float[] currentQuery = queries[q];
            for (int v = 0; v < numVectors; v++) {
                final float[] candidate = vectors[v];
                bh.consume(ESVectorUtil.soarDistance(currentQuery, candidate, candidate, 1.0f, 1.0f));
            }
        }
    }

    /** Bulk variant: one {@code ESVectorUtil.squareDistanceBulk} call per query and group of four vectors. */
    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public void squareDistanceBulk(Blackhole bh) {
        for (int q = 0; q < numQueries; q++) {
            final float[] currentQuery = queries[q];
            for (int base = 0; base < numVectors; base += 4) {
                ESVectorUtil.squareDistanceBulk(
                    currentQuery,
                    vectors[base],
                    vectors[base + 1],
                    vectors[base + 2],
                    vectors[base + 3],
                    distances
                );
                for (float result : distances) {
                    bh.consume(result);
                }
            }
        }
    }

    /**
     * Bulk variant: one {@code ESVectorUtil.soarDistanceBulk} call per query and group of four
     * vectors, passing the first vector of the group as the residual.
     */
    @Benchmark
    @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
    public void soarDistanceBulk(Blackhole bh) {
        for (int q = 0; q < numQueries; q++) {
            final float[] currentQuery = queries[q];
            for (int base = 0; base < numVectors; base += 4) {
                final float[] residual = vectors[base];
                ESVectorUtil.soarDistanceBulk(
                    currentQuery,
                    vectors[base],
                    vectors[base + 1],
                    vectors[base + 2],
                    vectors[base + 3],
                    residual,
                    1.0f,
                    1.0f,
                    distances
                );
                for (float result : distances) {
                    bh.consume(result);
                }
            }
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,79 @@ public static int quantizeVectorWithIntervals(float[] vector, int[] destination,
}
return IMPL.quantizeVectorWithIntervals(vector, destination, lowInterval, upperInterval, bit);
}

/**
 * Computes the square distance between a query vector and each of four vectors in one call.
 * The results are stored in the provided {@code distances} array.
 *
 * @param q the query vector
 * @param v0 the first vector
 * @param v1 the second vector
 * @param v2 the third vector
 * @param v3 the fourth vector
 * @param distances an array to store the computed square distances, must have length 4
 *
 * @throws IllegalArgumentException if any of the four vectors differs in length from the query,
 *         or if the distances array does not have length 4
 */
public static void squareDistanceBulk(float[] q, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances) {
    final int dims = q.length;
    // each vector is checked in argument order, so the first mismatch is the one reported
    if (v0.length != dims) {
        throw new IllegalArgumentException("vector dimensions differ: " + dims + "!=" + v0.length);
    }
    if (v1.length != dims) {
        throw new IllegalArgumentException("vector dimensions differ: " + dims + "!=" + v1.length);
    }
    if (v2.length != dims) {
        throw new IllegalArgumentException("vector dimensions differ: " + dims + "!=" + v2.length);
    }
    if (v3.length != dims) {
        throw new IllegalArgumentException("vector dimensions differ: " + dims + "!=" + v3.length);
    }
    final int slots = distances.length;
    if (slots != 4) {
        throw new IllegalArgumentException("distances array must have length 4, but was: " + slots);
    }
    IMPL.squareDistanceBulk(q, v0, v1, v2, v3, distances);
}

/**
 * Computes the soar distance from a vector to each of four centroids in one call.
 * The results are stored in the provided {@code distances} array.
 *
 * @param v1 the vector
 * @param c0 the first centroid
 * @param c1 the second centroid
 * @param c2 the third centroid
 * @param c3 the fourth centroid
 * @param originalResidual the residual with the actually nearest centroid
 * @param soarLambda the lambda parameter
 * @param rnorm distance to the nearest centroid
 * @param distances an array to store the computed soar distances, must have length 4
 *
 * @throws IllegalArgumentException if any centroid or the residual differs in length from
 *         {@code v1}, or if the distances array does not have length 4
 */
public static void soarDistanceBulk(
    float[] v1,
    float[] c0,
    float[] c1,
    float[] c2,
    float[] c3,
    float[] originalResidual,
    float soarLambda,
    float rnorm,
    float[] distances
) {
    final int dims = v1.length;
    // centroids first, then the residual, then the output buffer: the first failure is reported
    if (c0.length != dims) {
        throw new IllegalArgumentException("vector dimensions differ: " + dims + "!=" + c0.length);
    }
    if (c1.length != dims) {
        throw new IllegalArgumentException("vector dimensions differ: " + dims + "!=" + c1.length);
    }
    if (c2.length != dims) {
        throw new IllegalArgumentException("vector dimensions differ: " + dims + "!=" + c2.length);
    }
    if (c3.length != dims) {
        throw new IllegalArgumentException("vector dimensions differ: " + dims + "!=" + c3.length);
    }
    if (originalResidual.length != dims) {
        throw new IllegalArgumentException("vector dimensions differ: " + dims + "!=" + originalResidual.length);
    }
    final int slots = distances.length;
    if (slots != 4) {
        throw new IllegalArgumentException("distances array must have length 4, but was: " + slots);
    }
    IMPL.soarDistanceBulk(v1, c0, c1, c2, c3, originalResidual, soarLambda, rnorm, distances);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -293,4 +293,30 @@ public int quantizeVectorWithIntervals(float[] vector, int[] destination, float
}
return sumQuery;
}

/**
 * Computes the four square distances with four separate {@code VectorUtil.squareDistance}
 * calls, one per vector, and stores the result for {@code v0}..{@code v3} in
 * {@code distances[0]}..{@code distances[3]}.
 *
 * No argument validation is done here beyond what {@code VectorUtil.squareDistance} and the
 * array stores themselves enforce.
 */
@Override
public void squareDistanceBulk(float[] query, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances) {
// results are written one at a time, in order; a failure part-way leaves earlier slots filled
distances[0] = VectorUtil.squareDistance(query, v0);
distances[1] = VectorUtil.squareDistance(query, v1);
distances[2] = VectorUtil.squareDistance(query, v2);
distances[3] = VectorUtil.squareDistance(query, v3);
}

/**
 * Computes the four soar distances with four separate calls to {@code soarDistance} (declared
 * outside the lines shown here), one per centroid, and stores the result for
 * {@code c0}..{@code c3} in {@code distances[0]}..{@code distances[3]}.
 *
 * All four calls receive the same {@code v1}, {@code originalResidual}, {@code soarLambda}
 * and {@code rnorm}; only the centroid changes.
 */
@Override
public void soarDistanceBulk(
float[] v1,
float[] c0,
float[] c1,
float[] c2,
float[] c3,
float[] originalResidual,
float soarLambda,
float rnorm,
float[] distances
) {
// results are written one at a time, in centroid order
distances[0] = soarDistance(v1, c0, originalResidual, soarLambda, rnorm);
distances[1] = soarDistance(v1, c1, originalResidual, soarLambda, rnorm);
distances[2] = soarDistance(v1, c2, originalResidual, soarLambda, rnorm);
distances[3] = soarDistance(v1, c3, originalResidual, soarLambda, rnorm);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,17 @@ float calculateOSQLoss(

int quantizeVectorWithIntervals(float[] vector, int[] quantize, float lowInterval, float upperInterval, byte bit);

/**
 * Bulk computation of the square distances between a query vector and four vectors; the
 * results are written into {@code distances}.
 *
 * NOTE(review): the declaration itself states no preconditions. The implementations visible
 * alongside it write indices 0..3 of {@code distances} and read all vectors up to the query's
 * length, so callers presumably must pass equal-length vectors and a length-4 output array -
 * confirm against the validating caller.
 */
void squareDistanceBulk(float[] query, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances);

/**
 * Bulk computation of the soar distance from the vector {@code v1} to the four centroids
 * {@code c0}..{@code c3}; the results are written into {@code distances}.
 *
 * NOTE(review): the declaration itself states no preconditions. The implementations visible
 * alongside it write indices 0..3 of {@code distances} and share {@code originalResidual},
 * {@code soarLambda} and {@code rnorm} across all four centroids, so callers presumably must
 * pass equal-length arrays and a length-4 output array - confirm against the validating caller.
 */
void soarDistanceBulk(
float[] v1,
float[] c0,
float[] c1,
float[] c2,
float[] c3,
float[] originalResidual,
float soarLambda,
float rnorm,
float[] distances
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -822,4 +822,122 @@ public int quantizeVectorWithIntervals(float[] vector, int[] destination, float
}
return sumQuery;
}

/**
 * Computes the square distance from {@code query} to each of {@code v0}..{@code v3} in a single
 * pass, loading each chunk of the query once and reusing it for all four vectors.
 *
 * The main loop processes {@code FLOAT_SPECIES.length()} components per step with one
 * accumulator per vector; the components past the species loop bound are handled by a
 * scalar loop. Results are stored in {@code distances[0]}..{@code distances[3]}.
 */
@Override
public void squareDistanceBulk(float[] query, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances) {
    FloatVector sum0 = FloatVector.zero(FLOAT_SPECIES);
    FloatVector sum1 = FloatVector.zero(FLOAT_SPECIES);
    FloatVector sum2 = FloatVector.zero(FLOAT_SPECIES);
    FloatVector sum3 = FloatVector.zero(FLOAT_SPECIES);
    final int vectorizedEnd = FLOAT_SPECIES.loopBound(query.length);
    int idx = 0;
    for (; idx < vectorizedEnd; idx += FLOAT_SPECIES.length()) {
        // the query chunk is loaded once and shared by all four differences
        final FloatVector queryLanes = FloatVector.fromArray(FLOAT_SPECIES, query, idx);
        final FloatVector delta0 = queryLanes.sub(FloatVector.fromArray(FLOAT_SPECIES, v0, idx));
        final FloatVector delta1 = queryLanes.sub(FloatVector.fromArray(FLOAT_SPECIES, v1, idx));
        final FloatVector delta2 = queryLanes.sub(FloatVector.fromArray(FLOAT_SPECIES, v2, idx));
        final FloatVector delta3 = queryLanes.sub(FloatVector.fromArray(FLOAT_SPECIES, v3, idx));
        sum0 = fma(delta0, delta0, sum0);
        sum1 = fma(delta1, delta1, sum1);
        sum2 = fma(delta2, delta2, sum2);
        sum3 = fma(delta3, delta3, sum3);
    }
    // collapse each lane-wise accumulator to a scalar before the tail is added
    float total0 = sum0.reduceLanes(VectorOperators.ADD);
    float total1 = sum1.reduceLanes(VectorOperators.ADD);
    float total2 = sum2.reduceLanes(VectorOperators.ADD);
    float total3 = sum3.reduceLanes(VectorOperators.ADD);

    // scalar tail for the components beyond the loop bound
    for (; idx < query.length; idx++) {
        final float queryValue = query[idx];
        final float delta0 = queryValue - v0[idx];
        final float delta1 = queryValue - v1[idx];
        final float delta2 = queryValue - v2[idx];
        final float delta3 = queryValue - v3[idx];
        total0 = fma(delta0, delta0, total0);
        total1 = fma(delta1, delta1, total1);
        total2 = fma(delta2, delta2, total2);
        total3 = fma(delta3, delta3, total3);
    }
    distances[0] = total0;
    distances[1] = total1;
    distances[2] = total2;
    distances[3] = total3;
}

/**
 * Computes the soar distance from {@code v1} to each of the centroids {@code c0}..{@code c3}
 * in a single pass, loading each chunk of {@code v1} and {@code originalResidual} once and
 * reusing it for all four centroids.
 *
 * For each centroid two sums are accumulated over the components: the squared difference
 * between {@code v1} and the centroid, and the product of that difference with
 * {@code originalResidual}. The stored value is
 * {@code squaredSum + soarLambda * productSum * productSum / rnorm}. The main loop processes
 * {@code FLOAT_SPECIES.length()} components per step; the components past the species loop
 * bound are handled by a scalar loop. Results are stored in
 * {@code distances[0]}..{@code distances[3]}.
 */
@Override
public void soarDistanceBulk(
    float[] v1,
    float[] c0,
    float[] c1,
    float[] c2,
    float[] c3,
    float[] originalResidual,
    float soarLambda,
    float rnorm,
    float[] distances
) {
    FloatVector projLanes0 = FloatVector.zero(FLOAT_SPECIES);
    FloatVector projLanes1 = FloatVector.zero(FLOAT_SPECIES);
    FloatVector projLanes2 = FloatVector.zero(FLOAT_SPECIES);
    FloatVector projLanes3 = FloatVector.zero(FLOAT_SPECIES);
    FloatVector sqLanes0 = FloatVector.zero(FLOAT_SPECIES);
    FloatVector sqLanes1 = FloatVector.zero(FLOAT_SPECIES);
    FloatVector sqLanes2 = FloatVector.zero(FLOAT_SPECIES);
    FloatVector sqLanes3 = FloatVector.zero(FLOAT_SPECIES);
    final int vectorizedEnd = FLOAT_SPECIES.loopBound(v1.length);
    int idx = 0;
    for (; idx < vectorizedEnd; idx += FLOAT_SPECIES.length()) {
        final FloatVector vectorLanes = FloatVector.fromArray(FLOAT_SPECIES, v1, idx);
        final FloatVector centroidLanes0 = FloatVector.fromArray(FLOAT_SPECIES, c0, idx);
        final FloatVector centroidLanes1 = FloatVector.fromArray(FLOAT_SPECIES, c1, idx);
        final FloatVector centroidLanes2 = FloatVector.fromArray(FLOAT_SPECIES, c2, idx);
        final FloatVector centroidLanes3 = FloatVector.fromArray(FLOAT_SPECIES, c3, idx);
        final FloatVector residualLanes = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, idx);
        final FloatVector delta0 = vectorLanes.sub(centroidLanes0);
        final FloatVector delta1 = vectorLanes.sub(centroidLanes1);
        final FloatVector delta2 = vectorLanes.sub(centroidLanes2);
        final FloatVector delta3 = vectorLanes.sub(centroidLanes3);
        // product of the difference with the shared residual
        projLanes0 = fma(delta0, residualLanes, projLanes0);
        projLanes1 = fma(delta1, residualLanes, projLanes1);
        projLanes2 = fma(delta2, residualLanes, projLanes2);
        projLanes3 = fma(delta3, residualLanes, projLanes3);
        // squared difference
        sqLanes0 = fma(delta0, delta0, sqLanes0);
        sqLanes1 = fma(delta1, delta1, sqLanes1);
        sqLanes2 = fma(delta2, delta2, sqLanes2);
        sqLanes3 = fma(delta3, delta3, sqLanes3);
    }
    // collapse each lane-wise accumulator to a scalar before the tail is added
    float projection0 = projLanes0.reduceLanes(ADD);
    float squared0 = sqLanes0.reduceLanes(ADD);
    float projection1 = projLanes1.reduceLanes(ADD);
    float squared1 = sqLanes1.reduceLanes(ADD);
    float projection2 = projLanes2.reduceLanes(ADD);
    float squared2 = sqLanes2.reduceLanes(ADD);
    float projection3 = projLanes3.reduceLanes(ADD);
    float squared3 = sqLanes3.reduceLanes(ADD);
    // scalar tail for the components beyond the loop bound
    for (; idx < v1.length; idx++) {
        final float value = v1[idx];
        final float delta0 = value - c0[idx];
        final float delta1 = value - c1[idx];
        final float delta2 = value - c2[idx];
        final float delta3 = value - c3[idx];
        final float residual = originalResidual[idx];
        projection0 = fma(delta0, residual, projection0);
        projection1 = fma(delta1, residual, projection1);
        projection2 = fma(delta2, residual, projection2);
        projection3 = fma(delta3, residual, projection3);
        squared0 = fma(delta0, delta0, squared0);
        squared1 = fma(delta1, delta1, squared1);
        squared2 = fma(delta2, delta2, squared2);
        squared3 = fma(delta3, delta3, squared3);
    }
    distances[0] = squared0 + soarLambda * projection0 * projection0 / rnorm;
    distances[1] = squared1 + soarLambda * projection1 * projection1 / rnorm;
    distances[2] = squared2 + soarLambda * projection2 * projection2 / rnorm;
    distances[3] = squared3 + soarLambda * projection3 * projection3 / rnorm;
}
}
Loading