velesdb_core/collection/search/query/aggregation/
mod.rs1#![allow(clippy::cast_precision_loss)]
15#![allow(clippy::cast_possible_truncation)]
16#![allow(clippy::cast_sign_loss)]
17
18mod grouped;
19mod having;
20#[cfg(test)]
21mod having_tests;
22
23use super::where_eval::GraphMatchEvalCache;
24use crate::collection::types::Collection;
25use crate::error::Result;
26use crate::storage::{PayloadStorage, VectorStorage};
27use crate::velesql::{AggregateFunction, Aggregator, Query, SelectColumns};
28use rayon::prelude::*;
29use rustc_hash::FxHasher;
30use std::collections::HashMap;
31use std::hash::{Hash, Hasher};
32
33#[derive(Clone)]
36pub(crate) struct GroupKey {
37 pub(crate) values: Vec<serde_json::Value>,
39 hash: u64,
41}
42
43impl GroupKey {
44 pub(crate) fn new(values: Vec<serde_json::Value>) -> Self {
45 let hash = Self::compute_hash(&values);
46 Self { values, hash }
47 }
48
49 fn compute_hash(values: &[serde_json::Value]) -> u64 {
50 let mut hasher = FxHasher::default();
51 for v in values {
52 Self::hash_value(v, &mut hasher);
53 }
54 hasher.finish()
55 }
56
57 fn hash_value(value: &serde_json::Value, hasher: &mut FxHasher) {
58 match value {
59 serde_json::Value::Null => 0u8.hash(hasher),
60 serde_json::Value::Bool(b) => {
61 1u8.hash(hasher);
62 b.hash(hasher);
63 }
64 serde_json::Value::Number(n) => {
65 2u8.hash(hasher);
66 if let Some(f) = n.as_f64() {
68 f.to_bits().hash(hasher);
69 }
70 }
71 serde_json::Value::String(s) => {
72 3u8.hash(hasher);
73 s.hash(hasher);
74 }
75 _ => {
76 4u8.hash(hasher);
78 value.to_string().hash(hasher);
79 }
80 }
81 }
82}
83
84impl Hash for GroupKey {
85 fn hash<H: Hasher>(&self, state: &mut H) {
86 self.hash.hash(state);
87 }
88}
89
90impl PartialEq for GroupKey {
91 fn eq(&self, other: &Self) -> bool {
92 self.hash == other.hash && self.values == other.values
94 }
95}
96
97impl Eq for GroupKey {}
98
99pub(super) struct RuntimeWhereCtx<'a> {
101 pub(super) vector_storage: &'a dyn VectorStorage,
102 pub(super) stmt: &'a crate::velesql::SelectStatement,
103 pub(super) params: &'a HashMap<String, serde_json::Value>,
104 pub(super) needs_vector_eval: bool,
105 pub(super) graph_cache: &'a mut GraphMatchEvalCache,
106}
107
108struct SequentialAggCtx<'a> {
110 payload_storage: &'a dyn PayloadStorage,
111 vector_storage: &'a dyn VectorStorage,
112 stmt: &'a crate::velesql::SelectStatement,
113 params: &'a HashMap<String, serde_json::Value>,
114 filter: Option<&'a crate::filter::Filter>,
115 columns_to_aggregate: &'a [String],
116 has_count_star: bool,
117 use_runtime_where_eval: bool,
118}
119
/// Minimum number of ids at which ungrouped aggregation takes the parallel
/// path (and only when the WHERE clause needs no runtime evaluation).
const PARALLEL_THRESHOLD: usize = 10_000;
123
/// Number of payloads handed to each parallel aggregation task.
const CHUNK_SIZE: usize = 1000;
126
127impl Collection {
128 pub fn execute_aggregate(
150 &self,
151 query: &Query,
152 params: &HashMap<String, serde_json::Value>,
153 ) -> Result<serde_json::Value> {
154 let stmt = &query.select;
155
156 let aggregations: &[AggregateFunction] = match &stmt.columns {
157 SelectColumns::Aggregations(aggs) => aggs,
158 SelectColumns::Mixed { aggregations, .. } => aggregations,
159 _ => {
160 return Err(crate::error::Error::Config(
161 "execute_aggregate requires aggregation functions in SELECT".to_string(),
162 ))
163 }
164 };
165
166 if let Some(ref group_by) = stmt.group_by {
167 return self.execute_grouped_aggregate(
168 query,
169 aggregations,
170 &group_by.columns,
171 stmt.having.as_ref(),
172 params,
173 );
174 }
175
176 if stmt.having.is_some() {
177 return Err(crate::error::Error::Config(
178 "HAVING clause requires GROUP BY clause".to_string(),
179 ));
180 }
181
182 let agg_result = self.run_ungrouped_aggregation(stmt, aggregations, params)?;
183 Ok(Self::build_aggregate_result(aggregations, &agg_result))
184 }
185
186 fn run_ungrouped_aggregation(
188 &self,
189 stmt: &crate::velesql::SelectStatement,
190 aggregations: &[AggregateFunction],
191 params: &HashMap<String, serde_json::Value>,
192 ) -> Result<crate::velesql::AggregateResult> {
193 let where_clause = stmt.where_clause.as_ref();
194 let use_runtime_where_eval = where_clause.is_some_and(|cond| {
195 Self::condition_contains_graph_match(cond) || Self::condition_requires_vector_eval(cond)
196 });
197
198 let filter = Self::build_static_filter(where_clause, use_runtime_where_eval, params);
199 let (columns_vec, has_count_star) = Self::prepare_agg_columns(aggregations);
200
201 let payload_storage = self.payload_storage.read();
202 let vector_storage = self.vector_storage.read();
203 let ids: Vec<u64> = vector_storage.ids();
204
205 if ids.len() >= PARALLEL_THRESHOLD && !use_runtime_where_eval {
206 Ok(Self::run_parallel_path(
207 &ids,
208 &*payload_storage,
209 filter.as_ref(),
210 &columns_vec,
211 has_count_star,
212 ))
213 } else {
214 let ctx = SequentialAggCtx {
215 payload_storage: &*payload_storage,
216 vector_storage: &*vector_storage,
217 stmt,
218 params,
219 filter: filter.as_ref(),
220 columns_to_aggregate: &columns_vec,
221 has_count_star,
222 use_runtime_where_eval,
223 };
224 self.aggregate_sequential(&ids, &ctx)
225 }
226 }
227
228 fn run_parallel_path(
230 ids: &[u64],
231 payload_storage: &dyn PayloadStorage,
232 filter: Option<&crate::filter::Filter>,
233 columns_vec: &[String],
234 has_count_star: bool,
235 ) -> crate::velesql::AggregateResult {
236 let payloads: Vec<Option<serde_json::Value>> = ids
237 .iter()
238 .map(|&id| payload_storage.retrieve(id).ok().flatten())
239 .collect();
240
241 Self::aggregate_parallel(&payloads, filter, columns_vec, has_count_star)
242 }
243
244 pub(super) fn payload_passes_filter(
246 filter: &crate::filter::Filter,
247 payload: Option<&serde_json::Value>,
248 ) -> bool {
249 match payload {
250 Some(p) => filter.matches(p),
251 None => filter.matches(&serde_json::Value::Null),
252 }
253 }
254
255 pub(super) fn accumulate_record(
257 aggregator: &mut Aggregator,
258 payload: Option<&serde_json::Value>,
259 columns_to_aggregate: &[String],
260 has_count_star: bool,
261 ) {
262 if has_count_star {
263 aggregator.process_count();
264 }
265 if let Some(p) = payload {
266 for col in columns_to_aggregate {
267 if let Some(value) = Self::get_nested_value(p, col) {
268 aggregator.process_value(col, value);
269 }
270 }
271 }
272 }
273
274 fn aggregate_parallel(
276 payloads: &[Option<serde_json::Value>],
277 filter: Option<&crate::filter::Filter>,
278 columns_to_aggregate: &[String],
279 has_count_star: bool,
280 ) -> crate::velesql::AggregateResult {
281 let partial_aggregators: Vec<Aggregator> = payloads
282 .par_chunks(CHUNK_SIZE)
283 .map(|chunk| {
284 let mut chunk_agg = Aggregator::new();
285 for payload in chunk {
286 if let Some(f) = filter {
287 if !Self::payload_passes_filter(f, payload.as_ref()) {
288 continue;
289 }
290 }
291 Self::accumulate_record(
292 &mut chunk_agg,
293 payload.as_ref(),
294 columns_to_aggregate,
295 has_count_star,
296 );
297 }
298 chunk_agg
299 })
300 .collect();
301
302 let mut final_agg = Aggregator::new();
303 for partial in partial_aggregators {
304 final_agg.merge(partial);
305 }
306 final_agg.finalize()
307 }
308
309 pub(super) fn runtime_where_passes(
311 &self,
312 id: u64,
313 payload: Option<&serde_json::Value>,
314 ctx: &mut RuntimeWhereCtx<'_>,
315 ) -> Result<bool> {
316 let vector = if ctx.needs_vector_eval {
317 ctx.vector_storage.retrieve(id).ok().flatten()
318 } else {
319 None
320 };
321 match ctx.stmt.where_clause.as_ref() {
322 Some(cond) => self.evaluate_where_condition_for_record(
323 cond,
324 id,
325 payload,
326 vector.as_deref(),
327 ctx.params,
328 &ctx.stmt.from_alias,
329 ctx.graph_cache,
330 ),
331 None => Ok(true),
332 }
333 }
334
335 fn record_passes_filter(
337 &self,
338 id: u64,
339 payload: Option<&serde_json::Value>,
340 ctx: &SequentialAggCtx<'_>,
341 needs_vector_eval: bool,
342 graph_cache: &mut GraphMatchEvalCache,
343 ) -> Result<bool> {
344 if ctx.use_runtime_where_eval {
345 let mut where_ctx = RuntimeWhereCtx {
346 vector_storage: ctx.vector_storage,
347 stmt: ctx.stmt,
348 params: ctx.params,
349 needs_vector_eval,
350 graph_cache,
351 };
352 self.runtime_where_passes(id, payload, &mut where_ctx)
353 } else if let Some(f) = ctx.filter {
354 Ok(Self::payload_passes_filter(f, payload))
355 } else {
356 Ok(true)
357 }
358 }
359
360 fn aggregate_sequential(
362 &self,
363 ids: &[u64],
364 ctx: &SequentialAggCtx<'_>,
365 ) -> Result<crate::velesql::AggregateResult> {
366 let needs_vector_eval = ctx
367 .stmt
368 .where_clause
369 .as_ref()
370 .is_some_and(Self::condition_requires_vector_eval);
371 let mut aggregator = Aggregator::new();
372 let mut graph_cache = GraphMatchEvalCache::default();
373
374 for &id in ids {
375 let payload = ctx.payload_storage.retrieve(id).ok().flatten();
376 if !self.record_passes_filter(
377 id,
378 payload.as_ref(),
379 ctx,
380 needs_vector_eval,
381 &mut graph_cache,
382 )? {
383 continue;
384 }
385 Self::accumulate_record(
386 &mut aggregator,
387 payload.as_ref(),
388 ctx.columns_to_aggregate,
389 ctx.has_count_star,
390 );
391 }
392 Ok(aggregator.finalize())
393 }
394
395 fn build_aggregate_result(
397 aggregations: &[AggregateFunction],
398 agg_result: &crate::velesql::AggregateResult,
399 ) -> serde_json::Value {
400 let mut result = serde_json::Map::new();
401
402 for agg in aggregations {
403 let key = Self::aggregation_result_key(agg);
404 let value = Self::aggregation_result_value(agg, agg_result);
405 result.insert(key, value);
406 }
407
408 serde_json::Value::Object(result)
409 }
410
411 pub(crate) fn get_nested_value<'a>(
416 payload: &'a serde_json::Value,
417 path: &str,
418 ) -> Option<&'a serde_json::Value> {
419 let parts: Vec<&str> = path.split('.').collect();
420 let mut current = payload;
421
422 for part in parts {
423 match current {
424 serde_json::Value::Object(map) => {
425 current = map.get(part)?;
426 }
427 _ => return None,
428 }
429 }
430
431 Some(current)
432 }
433}