Skip to content

Support vector distance calculation in MySQL. #623

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: trunk
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions client/mysql.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1172,6 +1172,7 @@ static COMMANDS commands[] = {
{"FROM_VECTOR", 0, nullptr, false, ""},
{"VECTOR_TO_STRING", 0, nullptr, false, ""},
{"VECTOR_DIM", 0, nullptr, false, ""},
{"DISTANCE", 0, nullptr, false, ""},
{"UCASE", 0, nullptr, false, ""},
{"UNCOMPRESS", 0, nullptr, false, ""},
{"UNCOMPRESSED_LENGTH", 0, nullptr, false, ""},
Expand Down
9 changes: 9 additions & 0 deletions share/messages_to_clients.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10430,6 +10430,15 @@ ER_EXCEEDS_VECTOR_MAX_DIMENSIONS
ER_TO_VECTOR_CONVERSION
eng "Data cannot be converted to a valid vector: '%.*s'"

ER_VECTOR_INVALID_DATA
eng "Invalid vector data provided to function %s."

ER_VECTOR_DIM_NO_EQ
eng "Vector dim not equal: %d != %d"

ER_UNKNOWN_DISTANCE_TYPE
eng "Unknown distance type: '%.*s'"

OBSOLETE_ER_EXTERNAL_UNSUPPORTED_INDEX_ALGORITHM
eng "This storage engine ignores the %s index algorithm."

Expand Down
1 change: 1 addition & 0 deletions sql/item_create.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,7 @@ static const std::pair<const char *, Create_func *> func_array[] = {
{"FROM_VECTOR", SQL_FN(Item_func_from_vector, 1)},
{"VECTOR_TO_STRING", SQL_FN(Item_func_from_vector, 1)},
{"VECTOR_DIM", SQL_FN(Item_func_vector_dim, 1)},
{"DISTANCE", SQL_FN(Item_func_vector_distance, 3)},
{"UCASE", SQL_FN(Item_func_upper, 1)},
{"UNCOMPRESS", SQL_FN(Item_func_uncompress, 1)},
{"UNCOMPRESSED_LENGTH", SQL_FN(Item_func_uncompressed_length, 1)},
Expand Down
4 changes: 4 additions & 0 deletions sql/item_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,10 @@ class Item_real_func : public Item_func {
set_data_type_double();
}

// Three-operand form: hands pos and the operands to Item_func, then calls
// set_data_type_double() like the other constructors of this class.
Item_real_func(const POS &pos, Item *a, Item *b, Item *c)
    : Item_func(pos, a, b, c) {
  set_data_type_double();
}

// List form: forwards the argument list to Item_func and then calls
// set_data_type_double().
explicit Item_real_func(mem_root_deque<Item *> *list) : Item_func(list) {
set_data_type_double();
}
Expand Down
68 changes: 68 additions & 0 deletions sql/item_strfunc.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4326,6 +4326,74 @@ String *Item_func_from_vector::val_str_ascii(String *str) {
return &buffer;
}

double Item_func_vector_distance::val_real() {
assert(fixed);

String tmp_value1;
String tmp_value2;
String tmp_value3;
String *res1 = args[0]->val_str(&tmp_value1);
String *res2 = args[1]->val_str(&tmp_value2);
String *res3 = args[2]->val_str(&tmp_value3);

if ((null_value =
(!res1 || args[0]->null_value || !res2 || args[1]->null_value))) {
assert(is_nullable());
return 0.0;
}

if (res1 == nullptr || res2 == nullptr) {
my_error(ER_VECTOR_INVALID_DATA, MYF(0), func_name());
return error_real();
}

if (res3 == nullptr || res3->ptr() == nullptr) {
return error_real();
}

uint32 dimensions1 = get_dimensions(res1->length(), Field_vector::precision);
if (dimensions1 == UINT32_MAX) {
my_error(ER_TO_VECTOR_CONVERSION, MYF(0), res1->length(), res1->ptr());
return error_real();
}

uint32 dimensions2 = get_dimensions(res2->length(), Field_vector::precision);
if (dimensions2 == UINT32_MAX) {
my_error(ER_TO_VECTOR_CONVERSION, MYF(0), res2->length(), res2->ptr());
return error_real();
}

if (dimensions1 != dimensions2) {
my_error(ER_VECTOR_DIM_NO_EQ, MYF(0), dimensions1, dimensions2);
return error_real();
}

float distance = 0.0;
bool success = true;

// COSINE, DOT, and EUCLIDEAN
if (res3->length() == 3 && memcmp(res3->ptr(), "DOT", 3) == 0) {
success = vector_dot_distance(res1->ptr(), dimensions1, res2->ptr(),
dimensions2, &distance);
} else if (res3->length() == 6 && memcmp(res3->ptr(), "COSINE", 6) == 0) {
success = vector_cosine_distance(res1->ptr(), dimensions1, res2->ptr(),
dimensions2, &distance);
} else if (res3->length() == 9 && memcmp(res3->ptr(), "EUCLIDEAN", 9) == 0) {
success = vector_euclidean_distance(res1->ptr(), dimensions1, res2->ptr(),
dimensions2, &distance);
} else {
my_error(ER_UNKNOWN_DISTANCE_TYPE, MYF(0), res3->length(), res3->ptr());
return error_real();
}

if (!success) {
// ex. Division by zero
return error_real();
}

return distance;
}

String *Item_func_uncompress::val_str(String *str) {
assert(fixed);
String *res = args[0]->val_str(str);
Expand Down
13 changes: 13 additions & 0 deletions sql/item_strfunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,19 @@ class Item_func_from_vector final : public Item_str_ascii_func {
String *val_str_ascii(String *str) override;
};

/**
  Item for the three-argument SQL function registered as DISTANCE in
  item_create.cc. The result is a double and is always marked nullable.
*/
class Item_func_vector_distance final : public Item_real_func {
 public:
  Item_func_vector_distance(const POS &pos, Item *ilist1, Item *ilist2,
                            Item *ilist3)
      : Item_real_func(pos, ilist1, ilist2, ilist3) {
    set_nullable(true);
  }

  double val_real() override;

  // Keep this in sync with the name the function is registered under
  // ("DISTANCE"), so text produced from func_name() names a function the
  // server actually knows.
  const char *func_name() const override { return "distance"; }
};

class Item_func_uncompress final : public Item_str_func {
String buffer;

Expand Down
63 changes: 62 additions & 1 deletion vector-common/vector_conversion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,4 +170,65 @@ bool from_vector_to_string(const char *input, uint32_t input_dims, char *output,

*max_output_len = total_length;
return false;
}
}

#include <cstring>  // std::memcpy

/**
  Reads one float out of a byte buffer that holds consecutive floats.

  The value is copied out with memcpy instead of dereferencing a casted
  pointer: the buffer is a plain char array with no alignment guarantee for
  float, and reading it through a float pointer also breaks the aliasing
  rules.

  @param input  start of the buffer
  @param index  zero-based element index; the caller must make sure that
                (index + 1) * sizeof(float) bytes are readable
  @return the float stored at that element position
*/
float vector_data_at(const char *input, uint32_t index) {
  float value;
  std::memcpy(&value, input + static_cast<size_t>(index) * sizeof(float),
              sizeof(float));
  return value;
}

/**
  Euclidean (L2) distance between two float buffers of equal dimension.

  @param input1       first buffer, read through vector_data_at()
  @param input_dims1  number of elements in input1; drives the loop
  @param input2       second buffer, read through vector_data_at()
  @param input_dims2  number of elements in input2; only checked by assert
  @param[out] result  receives the square root of the summed squared
                      element differences
  @return always true
*/
bool vector_euclidean_distance(const char *input1, uint32_t input_dims1,
                               const char *input2, uint32_t input_dims2,
                               float *result) {
  assert(input_dims1 == input_dims2);

  float sum_of_squares = 0.0f;
  uint32_t idx = 0;
  while (idx < input_dims1) {
    const float lhs = vector_data_at(input1, idx);
    const float rhs = vector_data_at(input2, idx);
    const float delta = lhs - rhs;
    sum_of_squares += delta * delta;
    ++idx;
  }

  *result = sqrt(sum_of_squares);
  return true;
}

/**
  Cosine distance (1 - cosine similarity) between two float buffers of equal
  dimension.

  @param input1       first buffer, read through vector_data_at()
  @param input_dims1  number of elements in input1; drives the loop
  @param input2       second buffer, read through vector_data_at()
  @param input_dims2  number of elements in input2; only checked by assert
  @param[out] result  receives the distance; left untouched on failure
  @return false when either buffer has a zero squared norm (the similarity is
          undefined), true otherwise
*/
bool vector_cosine_distance(const char *input1, uint32_t input_dims1,
                            const char *input2, uint32_t input_dims2,
                            float *result) {
  assert(input_dims1 == input_dims2);
  float dot_product = 0.0f;
  float norm1 = 0.0f;
  float norm2 = 0.0f;

  for (uint32_t i = 0; i < input_dims1; i++) {
    const float d1 = vector_data_at(input1, i);
    const float d2 = vector_data_at(input2, i);
    dot_product += d1 * d2;
    norm1 += d1 * d1;
    norm2 += d2 * d2;
  }

  if (norm1 == 0.0f || norm2 == 0.0f) {
    return false;
  }

  float cos_sim = dot_product / (sqrt(norm1) * sqrt(norm2));

  // Rounding in the float accumulators can leave the ratio marginally outside
  // [-1, 1] for (anti-)parallel inputs, which would otherwise surface as a
  // slightly negative distance. Clamp so the distance stays within [0, 2].
  if (cos_sim > 1.0f) {
    cos_sim = 1.0f;
  } else if (cos_sim < -1.0f) {
    cos_sim = -1.0f;
  }

  *result = 1.0f - cos_sim;
  return true;
}

/**
  Dot product of two float buffers of equal dimension.

  @param input1       first buffer, read through vector_data_at()
  @param input_dims1  number of elements in input1; drives the loop
  @param input2       second buffer, read through vector_data_at()
  @param input_dims2  number of elements in input2; only checked by assert
  @param[out] result  receives the sum of the element-wise products
  @return always true
*/
bool vector_dot_distance(const char *input1, uint32_t input_dims1,
                         const char *input2, uint32_t input_dims2,
                         float *result) {
  assert(input_dims1 == input_dims2);

  float accumulated = 0.0f;
  uint32_t idx = 0;
  while (idx < input_dims1) {
    const float lhs = vector_data_at(input1, idx);
    const float rhs = vector_data_at(input2, idx);
    accumulated += lhs * rhs;
    ++idx;
  }

  *result = accumulated;
  return true;
}
14 changes: 14 additions & 0 deletions vector-common/vector_conversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,17 @@ bool from_string_to_vector(const CHARSET_INFO *cs, const char *input,

bool from_vector_to_string(const char *input, uint32_t input_dims, char *output,
uint32_t *max_output_len);

// Returns the float located index * sizeof(float) bytes past input.
float vector_data_at(const char *input, uint32_t index);

// Writes the Euclidean distance (square root of the summed squared element
// differences) to *result and returns true. Both dimension counts are
// expected to be equal; this is only checked by an assert.
bool vector_euclidean_distance(const char *input1, uint32_t input_dims1,
const char *input2, uint32_t input_dims2,
float *result);

// Writes 1 - cosine similarity to *result and returns true. Returns false,
// leaving *result untouched, when either input has a zero squared norm.
// Both dimension counts are expected to be equal (assert only).
bool vector_cosine_distance(const char *input1, uint32_t input_dims1,
const char *input2, uint32_t input_dims2,
float *result);

// Writes the dot product of the two inputs to *result and returns true.
// Both dimension counts are expected to be equal (assert only).
bool vector_dot_distance(const char *input1, uint32_t input_dims1,
const char *input2, uint32_t input_dims2,
float *result);