velesdb_core/collection/search/query/aggregation/
mod.rs1#![allow(clippy::cast_precision_loss)]
15#![allow(clippy::cast_possible_truncation)]
16#![allow(clippy::cast_sign_loss)]
17
18mod grouped;
19mod having;
20#[cfg(test)]
21mod having_tests;
22
23use super::where_eval::GraphMatchEvalCache;
24use crate::collection::types::Collection;
25use crate::error::Result;
26use crate::storage::{PayloadStorage, VectorStorage};
27use crate::velesql::{
28 AggregateArg, AggregateFunction, AggregateType, Aggregator, Query, SelectColumns,
29};
30use rayon::prelude::*;
31use rustc_hash::FxHasher;
32use std::collections::HashMap;
33use std::hash::{Hash, Hasher};
34
35#[derive(Clone)]
38pub(crate) struct GroupKey {
39 pub(crate) values: Vec<serde_json::Value>,
41 hash: u64,
43}
44
45impl GroupKey {
46 pub(crate) fn new(values: Vec<serde_json::Value>) -> Self {
47 let hash = Self::compute_hash(&values);
48 Self { values, hash }
49 }
50
51 fn compute_hash(values: &[serde_json::Value]) -> u64 {
52 let mut hasher = FxHasher::default();
53 for v in values {
54 Self::hash_value(v, &mut hasher);
55 }
56 hasher.finish()
57 }
58
59 fn hash_value(value: &serde_json::Value, hasher: &mut FxHasher) {
60 match value {
61 serde_json::Value::Null => 0u8.hash(hasher),
62 serde_json::Value::Bool(b) => {
63 1u8.hash(hasher);
64 b.hash(hasher);
65 }
66 serde_json::Value::Number(n) => {
67 2u8.hash(hasher);
68 if let Some(f) = n.as_f64() {
70 f.to_bits().hash(hasher);
71 }
72 }
73 serde_json::Value::String(s) => {
74 3u8.hash(hasher);
75 s.hash(hasher);
76 }
77 _ => {
78 4u8.hash(hasher);
80 value.to_string().hash(hasher);
81 }
82 }
83 }
84}
85
86impl Hash for GroupKey {
87 fn hash<H: Hasher>(&self, state: &mut H) {
88 self.hash.hash(state);
89 }
90}
91
92impl PartialEq for GroupKey {
93 fn eq(&self, other: &Self) -> bool {
94 self.hash == other.hash && self.values == other.values
96 }
97}
98
99impl Eq for GroupKey {}
100
/// Minimum record count at which the ungrouped aggregation switches to the
/// parallel (rayon) scan. The switch happens only when the WHERE clause needs
/// no per-record runtime evaluation.
const PARALLEL_THRESHOLD: usize = 10_000;

/// Number of payloads per chunk in the parallel scan. Each chunk yields one
/// partial `Aggregator`, and the partials are merged afterwards.
const CHUNK_SIZE: usize = 1_000;
107
108impl Collection {
109 #[allow(clippy::too_many_lines)]
131 pub fn execute_aggregate(
132 &self,
133 query: &Query,
134 params: &HashMap<String, serde_json::Value>,
135 ) -> Result<serde_json::Value> {
136 let stmt = &query.select;
137
138 let aggregations: &[AggregateFunction] = match &stmt.columns {
140 SelectColumns::Aggregations(aggs) => aggs,
141 SelectColumns::Mixed { aggregations, .. } => aggregations,
142 _ => {
143 return Err(crate::error::Error::Config(
144 "execute_aggregate requires aggregation functions in SELECT".to_string(),
145 ))
146 }
147 };
148
149 if let Some(ref group_by) = stmt.group_by {
151 return self.execute_grouped_aggregate(
152 query,
153 aggregations,
154 &group_by.columns,
155 stmt.having.as_ref(),
156 params,
157 );
158 }
159
160 if stmt.having.is_some() {
162 return Err(crate::error::Error::Config(
163 "HAVING clause requires GROUP BY clause".to_string(),
164 ));
165 }
166
167 let where_clause = stmt.where_clause.as_ref();
168 let use_runtime_where_eval = where_clause.is_some_and(|cond| {
169 Self::condition_contains_graph_match(cond) || Self::condition_requires_vector_eval(cond)
170 });
171 let needs_vector_eval = where_clause.is_some_and(Self::condition_requires_vector_eval);
172
173 let filter = if use_runtime_where_eval {
176 None
177 } else {
178 where_clause.as_ref().map(|cond| {
179 let resolved = Self::resolve_condition_params(cond, params);
180 crate::filter::Filter::new(crate::filter::Condition::from(resolved))
181 })
182 };
183
184 let mut aggregator = Aggregator::new();
186
187 let columns_to_aggregate: std::collections::HashSet<&str> = aggregations
189 .iter()
190 .filter_map(|agg| match &agg.argument {
191 AggregateArg::Column(col) => Some(col.as_str()),
192 AggregateArg::Wildcard => None, })
194 .collect();
195
196 let has_count_star = aggregations
197 .iter()
198 .any(|agg| matches!(agg.argument, AggregateArg::Wildcard));
199
200 let payload_storage = self.payload_storage.read();
202 let vector_storage = self.vector_storage.read();
203 let ids: Vec<u64> = vector_storage.ids();
204 let total_count = ids.len();
205
206 let agg_result = if total_count >= PARALLEL_THRESHOLD && !use_runtime_where_eval {
208 let payloads: Vec<Option<serde_json::Value>> = ids
210 .iter()
211 .map(|&id| payload_storage.retrieve(id).ok().flatten())
212 .collect();
213
214 drop(payload_storage);
216 drop(vector_storage);
217
218 let columns_vec: Vec<String> = columns_to_aggregate
219 .iter()
220 .map(|s| (*s).to_string())
221 .collect();
222
223 let partial_aggregators: Vec<Aggregator> = payloads
225 .par_chunks(CHUNK_SIZE)
226 .map(|chunk| {
227 let mut chunk_agg = Aggregator::new();
228 for payload in chunk {
229 if let Some(ref f) = filter {
231 let matches = match payload {
232 Some(ref p) => f.matches(p),
233 None => f.matches(&serde_json::Value::Null),
234 };
235 if !matches {
236 continue;
237 }
238 }
239
240 if has_count_star {
242 chunk_agg.process_count();
243 }
244
245 if let Some(ref p) = payload {
247 for col in &columns_vec {
248 if let Some(value) = Self::get_nested_value(p, col) {
249 chunk_agg.process_value(col, value);
250 }
251 }
252 }
253 }
254 chunk_agg
255 })
256 .collect();
257
258 let mut final_agg = Aggregator::new();
260 for partial in partial_aggregators {
261 final_agg.merge(partial);
262 }
263 final_agg.finalize()
264 } else {
265 let mut graph_cache = GraphMatchEvalCache::default();
267 for id in ids {
268 let payload = payload_storage.retrieve(id).ok().flatten();
269
270 if use_runtime_where_eval {
271 let vector = if needs_vector_eval {
272 vector_storage.retrieve(id).ok().flatten()
273 } else {
274 None
275 };
276 if let Some(cond) = where_clause {
277 let matches = self.evaluate_where_condition_for_record(
278 cond,
279 id,
280 payload.as_ref(),
281 vector.as_deref(),
282 params,
283 &stmt.from_alias,
284 &mut graph_cache,
285 )?;
286 if !matches {
287 continue;
288 }
289 }
290 } else if let Some(ref f) = filter {
291 let matches = match payload {
292 Some(ref p) => f.matches(p),
293 None => f.matches(&serde_json::Value::Null),
294 };
295 if !matches {
296 continue;
297 }
298 }
299
300 if has_count_star {
302 aggregator.process_count();
303 }
304
305 if let Some(ref p) = payload {
307 for col in &columns_to_aggregate {
308 if let Some(value) = Self::get_nested_value(p, col) {
309 aggregator.process_value(col, value);
310 }
311 }
312 }
313 }
314 aggregator.finalize()
315 };
316 let mut result = serde_json::Map::new();
317
318 for agg in aggregations {
320 let key = if let Some(ref alias) = agg.alias {
321 alias.clone()
322 } else {
323 match &agg.argument {
324 AggregateArg::Wildcard => "count".to_string(),
325 AggregateArg::Column(col) => {
326 let prefix = match agg.function_type {
327 AggregateType::Count => "count",
328 AggregateType::Sum => "sum",
329 AggregateType::Avg => "avg",
330 AggregateType::Min => "min",
331 AggregateType::Max => "max",
332 };
333 format!("{prefix}_{col}")
334 }
335 }
336 };
337
338 let value = match (&agg.function_type, &agg.argument) {
339 (AggregateType::Count, AggregateArg::Wildcard) => {
340 serde_json::json!(agg_result.count)
341 }
342 (AggregateType::Count, AggregateArg::Column(col)) => {
343 let count = agg_result.counts.get(col.as_str()).copied().unwrap_or(0);
345 serde_json::json!(count)
346 }
347 (AggregateType::Sum, AggregateArg::Column(col)) => agg_result
348 .sums
349 .get(col.as_str())
350 .map_or(serde_json::Value::Null, |v| serde_json::json!(v)),
351 (AggregateType::Avg, AggregateArg::Column(col)) => agg_result
352 .avgs
353 .get(col.as_str())
354 .map_or(serde_json::Value::Null, |v| serde_json::json!(v)),
355 (AggregateType::Min, AggregateArg::Column(col)) => agg_result
356 .mins
357 .get(col.as_str())
358 .map_or(serde_json::Value::Null, |v| serde_json::json!(v)),
359 (AggregateType::Max, AggregateArg::Column(col)) => agg_result
360 .maxs
361 .get(col.as_str())
362 .map_or(serde_json::Value::Null, |v| serde_json::json!(v)),
363 _ => serde_json::Value::Null,
364 };
365
366 result.insert(key, value);
367 }
368
369 Ok(serde_json::Value::Object(result))
370 }
371
372 pub(crate) fn get_nested_value<'a>(
374 payload: &'a serde_json::Value,
375 path: &str,
376 ) -> Option<&'a serde_json::Value> {
377 let parts: Vec<&str> = path.split('.').collect();
378 let mut current = payload;
379
380 for part in parts {
381 match current {
382 serde_json::Value::Object(map) => {
383 current = map.get(part)?;
384 }
385 _ => return None,
386 }
387 }
388
389 Some(current)
390 }
391}