-
-
Notifications
You must be signed in to change notification settings - Fork 863
Expand file tree
/
Copy pathvector_query_ops.cpp
More file actions
297 lines (239 loc) · 13.5 KB
/
vector_query_ops.cpp
File metadata and controls
297 lines (239 loc) · 13.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
#include "vector_query_ops.h"
#include "string_utils.h"
#include "collection.h"
// Parses a vector query string into `vector_query`.
//
// FORMAT:
//   field_name:([0.34, 0.66, 0.12, 0.68], k: 10)
//
// The text before the first `:` is the field name. The bracketed list holds the query vector
// (it may be empty). Anything after the closing `]` is treated as a comma separated list of
// `key: value` parameters. Recognised keys: `id`, `k`, `flat_search_cutoff`, `distance_threshold`,
// `alpha`, `ef`, `queries`, `image` and `query_weights`. Unrecognised keys are silently ignored.
//
// @param vector_query_str   Raw query string to parse.
// @param vector_query       Output. Fields are filled in as they are parsed, so it can be left
//                           partially populated when an error is returned. `field_name` is appended
//                           to (not reset), so callers should pass a fresh object.
// @param is_wildcard_query  Not used by this function.
// @param coll               Dereferenced WITHOUT a null check when the `id`, `distance_threshold`
//                           or `image` parameter is present; must be non-null for such queries.
// @param allow_empty_query  When true, an empty vector with no parameters at all is accepted.
// @return Option<bool>(true) on success, otherwise an Option carrying code 400 and a message.
Option<bool> VectorQueryOps::parse_vector_query_str(const std::string& vector_query_str,
                                                    vector_query_t& vector_query,
                                                    const bool is_wildcard_query,
                                                    const Collection* coll,
                                                    const bool allow_empty_query) {
    const size_t query_len = vector_query_str.size();
    size_t i = 0;

    // Everything up to the first `:` is the field name.
    for(; i < query_len && vector_query_str[i] != ':'; i++) {
        if(vector_query_str[i] == '(' || vector_query_str[i] == '[') {
            // If we hit a bracket before a colon, it's a missing colon error
            return Option<bool>(400, "Malformed vector query string: `:` is missing after the vector field name.");
        }

        vector_query.field_name += vector_query_str[i];
    }

    if(i == query_len) {
        // We hit the end of the string without finding a colon
        return Option<bool>(400, "Malformed vector query string: `:` is missing.");
    }

    i++;  // step over ':'
    StringUtils::trim(vector_query.field_name);

    // Moves `i` forward to the next occurrence of `target`; returns false if the input runs out first.
    const auto advance_to = [&](const char target) -> bool {
        while(i < query_len && vector_query_str[i] != target) {
            i++;
        }
        return i < query_len;
    };

    if(!advance_to('(')) {
        // missing "("
        return Option<bool>(400, "Malformed vector query string.");
    }
    i++;

    if(!advance_to('[')) {
        // missing opening "["
        return Option<bool>(400, "Malformed vector query string.");
    }
    i++;

    const size_t values_begin = i;
    if(!advance_to(']')) {
        // missing closing "]"
        return Option<bool>(400, "Malformed vector query string.");
    }

    std::string values_str = vector_query_str.substr(values_begin, i - values_begin);
    i++;  // step over ']'

    std::vector<std::string> svalues;
    StringUtils::split(values_str, svalues, ",");

    for(const auto& svalue: svalues) {
        if(!StringUtils::is_float(svalue)) {
            return Option<bool>(400, "Malformed vector query string: one of the vector values is not a float.");
        }

        vector_query.values.push_back(std::stof(svalue));
    }

    // Exactly one character left after `]` (normally the closing `)`): there are no parameters.
    if(i + 1 == query_len) {
        if(vector_query.values.empty() && !allow_empty_query) {
            // when query values are missing, atleast the `id` parameter must be present
            return Option<bool>(400, "When a vector query value is empty, an `id` parameter must be present.");
        }

        return Option<bool>(true);
    }

    std::string param_str = vector_query_str.substr(i);
    std::vector<std::string> param_kvs;
    StringUtils::split(param_str, param_kvs, ",");

    for(size_t p = 0; p < param_kvs.size(); p++) {
        auto& param_kv_str = param_kvs[p];

        // Strip the `)` that closes the whole query (guarded: back() on an empty string is undefined).
        if(!param_kv_str.empty() && param_kv_str.back() == ')') {
            param_kv_str.pop_back();
        }

        std::vector<std::string> param_kv;
        StringUtils::split(param_kv_str, param_kv, ":");
        if(param_kv.size() != 2) {
            return Option<bool>(400, "Malformed vector query string.");
        }

        const std::string& key = param_kv[0];
        std::string& value = param_kv[1];

        if(p + 1 < param_kvs.size() && !value.empty() && value.front() == '[' && value.back() != ']') {
            /*
                Parameters are split on commas (e.g. alpha:0.7, k:100), but array parameters also use
                commas between their elements. A query like embedding:([], qs:[x, y]) would therefore be
                split into `qs:[x` and `y])`.

                Workaround: when a value starts with '[' but does not end with ']', the parameter is
                incomplete. Prepend it to the next token and let the next iteration parse the merged text.
            */
            param_kvs[p + 1] = param_kv_str + "," + param_kvs[p + 1];
            continue;
        }

        if(key == "id") {
            if(!vector_query.values.empty()) {
                // cannot pass both vector values and id
                return Option<bool>(400, "Malformed vector query string: cannot pass both vector query "
                                         "and `id` parameter.");
            }

            Option<uint32_t> id_op = coll->doc_id_to_seq_id(value);
            if(!id_op.ok()) {
                return Option<bool>(400, "Document id referenced in vector query is not found.");
            }

            nlohmann::json document;
            auto doc_op = coll->get_document_from_store(id_op.get(), document);
            if(!doc_op.ok()) {
                return Option<bool>(400, "Document id referenced in vector query is not found.");
            }

            if(!document.contains(vector_query.field_name) || !document[vector_query.field_name].is_array()) {
                return Option<bool>(400, "Document referenced in vector query does not contain a valid "
                                         "vector field.");
            }

            // The referenced document's own vector becomes the query vector.
            for(auto& fvalue: document[vector_query.field_name]) {
                if(!fvalue.is_number()) {
                    return Option<bool>(400, "Document referenced in vector query does not contain a valid "
                                             "vector field.");
                }

                vector_query.values.push_back(fvalue.get<float>());
            }

            vector_query.query_doc_given = true;
            vector_query.seq_id = id_op.get();
        } else if(key == "k") {
            if(!StringUtils::is_uint32_t(value)) {
                return Option<bool>(400, "Malformed vector query string: `k` parameter must be an integer.");
            }

            vector_query.k = std::stoul(value);
        } else if(key == "flat_search_cutoff") {
            if(!StringUtils::is_uint32_t(value)) {
                return Option<bool>(400, "Malformed vector query string: "
                                         "`flat_search_cutoff` parameter must be an integer.");
            }

            // std::stoul rather than std::stoi: the value was validated as uint32, and std::stoi throws
            // std::out_of_range for anything above INT_MAX.
            vector_query.flat_search_cutoff = std::stoul(value);
        } else if(key == "distance_threshold") {
            auto search_schema = const_cast<Collection*>(coll)->get_schema();
            auto vector_field_it = search_schema.find(vector_query.field_name);

            if(vector_field_it == search_schema.end()) {
                return Option<bool>(400, "Malformed vector query string: could not find a field named "
                                         "`" + vector_query.field_name + "`.");
            }

            if(!StringUtils::is_float(value)) {
                return Option<bool>(400, "Malformed vector query string: "
                                         "`distance_threshold` parameter must be a float.");
            }

            auto distance_threshold = std::stof(value);
            // The range is only enforced for fields that use cosine distance.
            if(vector_field_it->vec_dist == cosine && (distance_threshold < 0.0 || distance_threshold > 2.0)) {
                return Option<bool>(400, "Malformed vector query string: "
                                         "`distance_threshold` parameter must be a float between 0.0-2.0.");
            }

            vector_query.distance_threshold = distance_threshold;
        } else if(key == "alpha") {
            bool alpha_ok = StringUtils::is_float(value);
            float alpha = 0.0f;

            if(alpha_ok) {
                alpha = std::stof(value);
                alpha_ok = !(alpha < 0.0 || alpha > 1.0);
            }

            if(!alpha_ok) {
                return Option<bool>(400, "Malformed vector query string: "
                                         "`alpha` parameter must be a float between 0.0-1.0.");
            }

            vector_query.alpha = alpha;
        } else if(key == "ef") {
            bool ef_ok = StringUtils::is_uint32_t(value);
            unsigned long ef = 0;

            if(ef_ok) {
                ef = std::stoul(value);
                ef_ok = (ef != 0);
            }

            if(!ef_ok) {
                return Option<bool>(400, "Malformed vector query string: `ef` parameter must be a positive integer.");
            }

            vector_query.ef = ef;
        } else if(key == "queries") {
            if(value.empty() || value.front() != '[' || value.back() != ']') {
                return Option<bool>(400, "Malformed vector query string: "
                                         "`queries` parameter must be a list of strings.");
            }

            // Drop the surrounding brackets before splitting the list.
            value.erase(0, 1);
            value.pop_back();

            std::vector<std::string> qs;
            StringUtils::split_list_with_backticks(value, qs);
            for(auto& q: qs) {
                StringUtils::trim(q);
                vector_query.queries.push_back(q);
            }
        } else if(key == "image") {
            auto search_schema = const_cast<Collection*>(coll)->get_schema();
            auto vector_field_it = search_schema.find(vector_query.field_name);

            if(vector_field_it == search_schema.end()) {
                return Option<bool>(400, "Malformed vector query string: could not find a field named "
                                         "`" + vector_query.field_name + "`.");
            }

            if(vector_field_it->embed.empty()) {
                return Option<bool>(400, "Malformed vector query string: `image` parameter is not supported "
                                         "for this field.");
            }

            auto model_config = vector_field_it->embed["model_config"];
            auto image_embedder_op = EmbedderManager::get_instance().get_image_embedder(model_config);
            if(!image_embedder_op.ok()) {
                return Option<bool>(400, "Malformed vector query string: could not get image embedder.");
            }

            auto image_embedder = image_embedder_op.get();
            auto res = image_embedder->embed_documents({value});
            if(res.empty() || !res[0].success) {
                return Option<bool>(400, "Malformed vector query string: could not embed image.");
            }

            // The embedding of the supplied image replaces the query vector.
            vector_query.values = res[0].embedding;
        } else if(key == "query_weights") {
            if(value.empty() || value.front() != '[' || value.back() != ']') {
                return Option<bool>(400, "Malformed vector query string: "
                                         "`query_weights` parameter must be a list of floats.");
            }

            // Drop the surrounding brackets before splitting the list.
            value.erase(0, 1);
            value.pop_back();

            std::vector<std::string> ws;
            StringUtils::split(value, ws, ",");
            for(auto& w: ws) {
                StringUtils::trim(w);
                if(!StringUtils::is_float(w)) {
                    return Option<bool>(400, "Malformed vector query string: "
                                             "`query_weights` parameter must be a list of floats.");
                }

                vector_query.query_weights.push_back(std::stof(w));
            }
        }
    }

    // Weights are optional, but when present there must be one per query.
    if(vector_query.queries.size() != vector_query.query_weights.size() && !vector_query.query_weights.empty()) {
        return Option<bool>(400, "Malformed vector query string: "
                                 "`queries` and `query_weights` must be of the same length.");
    }

    if(!vector_query.query_weights.empty()) {
        // Accumulate in double and compare with a tolerance: decimal weights such as 0.1 are not exactly
        // representable, so an exact `sum != 1.0` check can reject weights that do add up to 1.
        double weight_sum = 0.0;
        for(const auto& w: vector_query.query_weights) {
            weight_sum += w;
        }

        const double weight_sum_tolerance = 1e-6;
        // Written as a negated range check so that a NaN sum is still rejected.
        if(!(weight_sum >= 1.0 - weight_sum_tolerance && weight_sum <= 1.0 + weight_sum_tolerance)) {
            return Option<bool>(400, "Malformed vector query string: "
                                     "`query_weights` must sum to 1.0.");
        }
    }

    return Option<bool>(true);
}