velesdb_core/collection/search/query/aggregation/
mod.rs1#![allow(clippy::cast_precision_loss)]
15#![allow(clippy::cast_possible_truncation)]
16#![allow(clippy::cast_sign_loss)]
17
18mod grouped;
19mod having;
20
21use super::where_eval::GraphMatchEvalCache;
22use crate::collection::types::Collection;
23use crate::error::Result;
24use crate::storage::{PayloadStorage, VectorStorage};
25use crate::velesql::{
26 AggregateArg, AggregateFunction, AggregateType, Aggregator, Query, SelectColumns,
27};
28use rayon::prelude::*;
29use rustc_hash::FxHasher;
30use std::collections::HashMap;
31use std::hash::{Hash, Hasher};
32
33#[derive(Clone)]
36pub(crate) struct GroupKey {
37 pub(crate) values: Vec<serde_json::Value>,
39 hash: u64,
41}
42
43impl GroupKey {
44 pub(crate) fn new(values: Vec<serde_json::Value>) -> Self {
45 let hash = Self::compute_hash(&values);
46 Self { values, hash }
47 }
48
49 fn compute_hash(values: &[serde_json::Value]) -> u64 {
50 let mut hasher = FxHasher::default();
51 for v in values {
52 Self::hash_value(v, &mut hasher);
53 }
54 hasher.finish()
55 }
56
57 fn hash_value(value: &serde_json::Value, hasher: &mut FxHasher) {
58 match value {
59 serde_json::Value::Null => 0u8.hash(hasher),
60 serde_json::Value::Bool(b) => {
61 1u8.hash(hasher);
62 b.hash(hasher);
63 }
64 serde_json::Value::Number(n) => {
65 2u8.hash(hasher);
66 if let Some(f) = n.as_f64() {
68 f.to_bits().hash(hasher);
69 }
70 }
71 serde_json::Value::String(s) => {
72 3u8.hash(hasher);
73 s.hash(hasher);
74 }
75 _ => {
76 4u8.hash(hasher);
78 value.to_string().hash(hasher);
79 }
80 }
81 }
82}
83
84impl Hash for GroupKey {
85 fn hash<H: Hasher>(&self, state: &mut H) {
86 self.hash.hash(state);
87 }
88}
89
90impl PartialEq for GroupKey {
91 fn eq(&self, other: &Self) -> bool {
92 self.hash == other.hash && self.values == other.values
94 }
95}
96
97impl Eq for GroupKey {}
98
/// Minimum number of stored ids at which `execute_aggregate` switches from the
/// sequential loop to rayon-based parallel aggregation.
const PARALLEL_THRESHOLD: usize = 10_000;

/// Number of payloads given to each parallel task in `execute_aggregate`.
const CHUNK_SIZE: usize = 1000;

106impl Collection {
107 #[allow(clippy::too_many_lines)]
129 pub fn execute_aggregate(
130 &self,
131 query: &Query,
132 params: &HashMap<String, serde_json::Value>,
133 ) -> Result<serde_json::Value> {
134 let stmt = &query.select;
135
136 let aggregations: &[AggregateFunction] = match &stmt.columns {
138 SelectColumns::Aggregations(aggs) => aggs,
139 SelectColumns::Mixed { aggregations, .. } => aggregations,
140 _ => {
141 return Err(crate::error::Error::Config(
142 "execute_aggregate requires aggregation functions in SELECT".to_string(),
143 ))
144 }
145 };
146
147 if let Some(ref group_by) = stmt.group_by {
149 return self.execute_grouped_aggregate(
150 query,
151 aggregations,
152 &group_by.columns,
153 stmt.having.as_ref(),
154 params,
155 );
156 }
157
158 if stmt.having.is_some() {
160 return Err(crate::error::Error::Config(
161 "HAVING clause requires GROUP BY clause".to_string(),
162 ));
163 }
164
165 let where_clause = stmt.where_clause.as_ref();
166 let use_runtime_where_eval = where_clause.is_some_and(|cond| {
167 Self::condition_contains_graph_match(cond) || Self::condition_requires_vector_eval(cond)
168 });
169 let needs_vector_eval = where_clause.is_some_and(Self::condition_requires_vector_eval);
170
171 let filter = if use_runtime_where_eval {
174 None
175 } else {
176 where_clause.as_ref().map(|cond| {
177 let resolved = Self::resolve_condition_params(cond, params);
178 crate::filter::Filter::new(crate::filter::Condition::from(resolved))
179 })
180 };
181
182 let mut aggregator = Aggregator::new();
184
185 let columns_to_aggregate: std::collections::HashSet<&str> = aggregations
187 .iter()
188 .filter_map(|agg| match &agg.argument {
189 AggregateArg::Column(col) => Some(col.as_str()),
190 AggregateArg::Wildcard => None, })
192 .collect();
193
194 let has_count_star = aggregations
195 .iter()
196 .any(|agg| matches!(agg.argument, AggregateArg::Wildcard));
197
198 let payload_storage = self.payload_storage.read();
200 let vector_storage = self.vector_storage.read();
201 let ids: Vec<u64> = vector_storage.ids();
202 let total_count = ids.len();
203
204 let agg_result = if total_count >= PARALLEL_THRESHOLD && !use_runtime_where_eval {
206 let payloads: Vec<Option<serde_json::Value>> = ids
208 .iter()
209 .map(|&id| payload_storage.retrieve(id).ok().flatten())
210 .collect();
211
212 drop(payload_storage);
214 drop(vector_storage);
215
216 let columns_vec: Vec<String> = columns_to_aggregate
217 .iter()
218 .map(|s| (*s).to_string())
219 .collect();
220
221 let partial_aggregators: Vec<Aggregator> = payloads
223 .par_chunks(CHUNK_SIZE)
224 .map(|chunk| {
225 let mut chunk_agg = Aggregator::new();
226 for payload in chunk {
227 if let Some(ref f) = filter {
229 let matches = match payload {
230 Some(ref p) => f.matches(p),
231 None => f.matches(&serde_json::Value::Null),
232 };
233 if !matches {
234 continue;
235 }
236 }
237
238 if has_count_star {
240 chunk_agg.process_count();
241 }
242
243 if let Some(ref p) = payload {
245 for col in &columns_vec {
246 if let Some(value) = Self::get_nested_value(p, col) {
247 chunk_agg.process_value(col, value);
248 }
249 }
250 }
251 }
252 chunk_agg
253 })
254 .collect();
255
256 let mut final_agg = Aggregator::new();
258 for partial in partial_aggregators {
259 final_agg.merge(partial);
260 }
261 final_agg.finalize()
262 } else {
263 let mut graph_cache = GraphMatchEvalCache::default();
265 for id in ids {
266 let payload = payload_storage.retrieve(id).ok().flatten();
267
268 if use_runtime_where_eval {
269 let vector = if needs_vector_eval {
270 vector_storage.retrieve(id).ok().flatten()
271 } else {
272 None
273 };
274 if let Some(cond) = where_clause {
275 let matches = self.evaluate_where_condition_for_record(
276 cond,
277 id,
278 payload.as_ref(),
279 vector.as_deref(),
280 params,
281 stmt.from_alias.as_deref(),
282 &mut graph_cache,
283 )?;
284 if !matches {
285 continue;
286 }
287 }
288 } else if let Some(ref f) = filter {
289 let matches = match payload {
290 Some(ref p) => f.matches(p),
291 None => f.matches(&serde_json::Value::Null),
292 };
293 if !matches {
294 continue;
295 }
296 }
297
298 if has_count_star {
300 aggregator.process_count();
301 }
302
303 if let Some(ref p) = payload {
305 for col in &columns_to_aggregate {
306 if let Some(value) = Self::get_nested_value(p, col) {
307 aggregator.process_value(col, value);
308 }
309 }
310 }
311 }
312 aggregator.finalize()
313 };
314 let mut result = serde_json::Map::new();
315
316 for agg in aggregations {
318 let key = if let Some(ref alias) = agg.alias {
319 alias.clone()
320 } else {
321 match &agg.argument {
322 AggregateArg::Wildcard => "count".to_string(),
323 AggregateArg::Column(col) => {
324 let prefix = match agg.function_type {
325 AggregateType::Count => "count",
326 AggregateType::Sum => "sum",
327 AggregateType::Avg => "avg",
328 AggregateType::Min => "min",
329 AggregateType::Max => "max",
330 };
331 format!("{prefix}_{col}")
332 }
333 }
334 };
335
336 let value = match (&agg.function_type, &agg.argument) {
337 (AggregateType::Count, AggregateArg::Wildcard) => {
338 serde_json::json!(agg_result.count)
339 }
340 (AggregateType::Count, AggregateArg::Column(col)) => {
341 let count = agg_result.counts.get(col.as_str()).copied().unwrap_or(0);
343 serde_json::json!(count)
344 }
345 (AggregateType::Sum, AggregateArg::Column(col)) => agg_result
346 .sums
347 .get(col.as_str())
348 .map_or(serde_json::Value::Null, |v| serde_json::json!(v)),
349 (AggregateType::Avg, AggregateArg::Column(col)) => agg_result
350 .avgs
351 .get(col.as_str())
352 .map_or(serde_json::Value::Null, |v| serde_json::json!(v)),
353 (AggregateType::Min, AggregateArg::Column(col)) => agg_result
354 .mins
355 .get(col.as_str())
356 .map_or(serde_json::Value::Null, |v| serde_json::json!(v)),
357 (AggregateType::Max, AggregateArg::Column(col)) => agg_result
358 .maxs
359 .get(col.as_str())
360 .map_or(serde_json::Value::Null, |v| serde_json::json!(v)),
361 _ => serde_json::Value::Null,
362 };
363
364 result.insert(key, value);
365 }
366
367 Ok(serde_json::Value::Object(result))
368 }
369
370 pub(crate) fn get_nested_value<'a>(
372 payload: &'a serde_json::Value,
373 path: &str,
374 ) -> Option<&'a serde_json::Value> {
375 let parts: Vec<&str> = path.split('.').collect();
376 let mut current = payload;
377
378 for part in parts {
379 match current {
380 serde_json::Value::Object(map) => {
381 current = map.get(part)?;
382 }
383 _ => return None,
384 }
385 }
386
387 Some(current)
388 }
389}