1use std::{
42 collections::HashMap,
43 io::{self, Cursor},
44 sync::Arc,
45};
46
47use ahash::AHashMap;
48use vibesql_storage::Row;
49use vibesql_types::SqlValue;
50
51use super::{
52 row_serialization::{deserialize_row, serialize_row},
53 MemoryController, MemoryReservation, SpillFile,
54};
55use crate::select::grouping::AggregateAccumulator;
56
/// Tuning knobs for `ExternalAggregate`.
#[derive(Debug, Clone)]
pub struct ExternalAggregateConfig {
    /// Number of hash partitions. Rounded up to the next power of two by
    /// `ExternalAggregate::with_config` so partition selection can use a mask.
    pub num_partitions: usize,

    /// Soft cap on in-memory groups per partition.
    /// NOTE(review): not consulted anywhere in this file — confirm callers
    /// use it, or whether it is dead configuration.
    pub max_groups_per_partition: usize,
}
66
67impl Default for ExternalAggregateConfig {
68 fn default() -> Self {
69 Self {
70 num_partitions: 64, max_groups_per_partition: 10_000, }
73 }
74}
75
/// Describes one aggregate expression to compute per group.
#[derive(Debug, Clone)]
pub struct AggregateSpec {
    /// Aggregate function name (e.g. "COUNT", "SUM", "AVG", "MIN", "MAX"),
    /// passed through to `AggregateAccumulator::new`.
    pub function_name: String,

    /// Whether the aggregate is DISTINCT (accumulators then track seen values).
    pub distinct: bool,

    /// Input column offset relative to the end of the key columns:
    /// the value is read at row index `num_key_columns + value_index`.
    pub value_index: usize,
}
88
/// One hash partition of the aggregation state.
struct Partition {
    /// In-memory groups: serialized key bytes -> (key values, accumulators).
    groups: AHashMap<Vec<u8>, (Vec<SqlValue>, Vec<AggregateAccumulator>)>,

    /// Estimated bytes held by `groups`, as charged to the memory reservation.
    memory_bytes: usize,

    /// True once this partition has switched to spill mode; further rows for
    /// it are written to `spill_file` instead of aggregated in memory.
    spilled: bool,

    /// Lazily-created file holding length-prefixed serialized rows.
    spill_file: Option<SpillFile>,

    /// Count of raw rows written to `spill_file`.
    rows_spilled: usize,
}
108
109impl Partition {
110 fn new() -> Self {
111 Self {
112 groups: AHashMap::new(),
113 memory_bytes: 0,
114 spilled: false,
115 spill_file: None,
116 rows_spilled: 0,
117 }
118 }
119
120 fn estimate_group_memory(
122 key_values: &[SqlValue],
123 accumulators: &[AggregateAccumulator],
124 ) -> usize {
125 let key_size: usize = key_values.iter().map(|v| v.estimated_size_bytes()).sum();
126 let acc_size = std::mem::size_of_val(accumulators)
127 + accumulators.iter().map(estimate_accumulator_memory).sum::<usize>();
128 key_size + acc_size + 64 }
130}
131
132fn estimate_accumulator_memory(acc: &AggregateAccumulator) -> usize {
134 match acc {
135 AggregateAccumulator::Count { seen: Some(set), .. } => set.len() * 48,
136 AggregateAccumulator::Sum { seen: Some(set), .. } => set.len() * 48,
137 AggregateAccumulator::Avg { seen: Some(set), .. } => set.len() * 48,
138 AggregateAccumulator::Min { seen: Some(set), .. } => set.len() * 48,
139 AggregateAccumulator::Max { seen: Some(set), .. } => set.len() * 48,
140 _ => 0,
141 }
142}
143
/// Hash aggregation that spills partitions to disk under memory pressure.
///
/// Rows are hashed into a fixed, power-of-two number of partitions. Each
/// partition aggregates in memory until the shared memory budget is
/// exhausted, at which point the largest partition is spilled and further
/// rows routed to it go to disk; `finish` re-aggregates spilled rows.
pub struct ExternalAggregate {
    /// Memory reservation charged for in-memory group state.
    reservation: MemoryReservation,

    /// Effective configuration (partition count rounded to a power of two).
    #[allow(dead_code)]
    config: ExternalAggregateConfig,

    /// One accumulator per spec is maintained for every group.
    aggregate_specs: Vec<AggregateSpec>,

    /// Number of leading row values that form the grouping key.
    num_key_columns: usize,

    /// Fixed set of hash partitions (power-of-two length).
    partitions: Vec<Partition>,

    /// `partitions.len() - 1`; masks a key hash down to a partition index.
    partition_mask: usize,

    /// Total groups currently held in memory across all partitions.
    total_groups: usize,
}
170
impl ExternalAggregate {
    /// Creates an aggregator with the default [`ExternalAggregateConfig`].
    pub fn new(
        controller: &Arc<MemoryController>,
        num_key_columns: usize,
        aggregate_specs: Vec<AggregateSpec>,
    ) -> Self {
        Self::with_config(
            controller,
            num_key_columns,
            aggregate_specs,
            ExternalAggregateConfig::default(),
        )
    }

    /// Creates an aggregator with an explicit configuration.
    ///
    /// `config.num_partitions` is rounded up to the next power of two so the
    /// partition index can be computed with `hash & partition_mask` instead
    /// of a modulo.
    pub fn with_config(
        controller: &Arc<MemoryController>,
        num_key_columns: usize,
        aggregate_specs: Vec<AggregateSpec>,
        config: ExternalAggregateConfig,
    ) -> Self {
        let num_partitions = config.num_partitions.next_power_of_two();

        let mut partitions = Vec::with_capacity(num_partitions);
        for _ in 0..num_partitions {
            partitions.push(Partition::new());
        }

        Self {
            reservation: controller.create_reservation(),
            // Store the rounded count so the kept config reflects reality.
            config: ExternalAggregateConfig { num_partitions, ..config },
            aggregate_specs,
            num_key_columns,
            partitions,
            partition_mask: num_partitions - 1,
            total_groups: 0,
        }
    }

    /// Feeds one input row into the aggregation.
    ///
    /// Row layout: the first `num_key_columns` values form the grouping key;
    /// each aggregate input is read at `num_key_columns + spec.value_index`.
    ///
    /// Rows hashed to an already-spilled partition are appended to that
    /// partition's spill file and re-aggregated during `finish`. If a new
    /// group does not fit in the memory budget, the largest in-memory
    /// partition is spilled first and the reservation is retried.
    ///
    /// # Errors
    /// Returns an I/O error from spill-file writes, or `OutOfMemory` when a
    /// single group exceeds the available budget even after spilling.
    pub fn add_row(&mut self, row: &[SqlValue]) -> io::Result<()> {
        let key_values = &row[..self.num_key_columns];

        let partition_idx = self.compute_partition(key_values);
        let partition = &mut self.partitions[partition_idx];

        // Spilled partitions no longer aggregate in memory: append raw rows.
        if partition.spilled {
            self.spill_row_to_partition(partition_idx, row)?;
            return Ok(());
        }

        let key_bytes = serialize_key(key_values);

        if let Some((_, accumulators)) = partition.groups.get_mut(&key_bytes) {
            // Existing group: fold the row into its accumulators.
            for (spec, acc) in self.aggregate_specs.iter().zip(accumulators.iter_mut()) {
                let value = &row[self.num_key_columns + spec.value_index];
                acc.accumulate(value);
            }
        } else {
            // New group: size it so we can reserve memory before inserting.
            // NOTE(review): this accumulator set exists only to estimate the
            // footprint; a second, identical set is created further down once
            // the reservation succeeds. Consider reusing one set.
            let accumulators = self.create_accumulators();
            let group_memory = Partition::estimate_group_memory(key_values, &accumulators);

            if !self.reservation.try_grow(group_memory) {
                // Budget exhausted: free memory by spilling the largest
                // in-memory partition, then retry the reservation.
                self.spill_largest_partition()?;

                // Re-borrow: our own partition may have been the one spilled.
                let partition = &mut self.partitions[partition_idx];
                if partition.spilled {
                    self.spill_row_to_partition(partition_idx, row)?;
                    return Ok(());
                }

                if !self.reservation.try_grow(group_memory) {
                    return Err(io::Error::new(
                        io::ErrorKind::OutOfMemory,
                        "single group exceeds available memory budget",
                    ));
                }
            }

            let mut accumulators = self.create_accumulators();
            for (spec, acc) in self.aggregate_specs.iter().zip(accumulators.iter_mut()) {
                let value = &row[self.num_key_columns + spec.value_index];
                acc.accumulate(value);
            }

            let partition = &mut self.partitions[partition_idx];
            partition.groups.insert(key_bytes, (key_values.to_vec(), accumulators));
            partition.memory_bytes += group_memory;
            self.total_groups += 1;
        }

        Ok(())
    }

    /// Builds one fresh accumulator per aggregate spec.
    fn create_accumulators(&self) -> Vec<AggregateAccumulator> {
        self.aggregate_specs
            .iter()
            .map(|spec| {
                AggregateAccumulator::new(&spec.function_name, spec.distinct)
                    .expect("aggregate spec should be valid")
            })
            .collect()
    }

    /// Hashes the key values and masks the hash down to a partition index.
    fn compute_partition(&self, key_values: &[SqlValue]) -> usize {
        use std::hash::Hasher;
        let mut hasher = ahash::AHasher::default();
        for v in key_values {
            hash_sql_value(v, &mut hasher);
        }
        // partition_mask == num_partitions - 1, and num_partitions is a
        // power of two, so this is equivalent to `hash % num_partitions`.
        (hasher.finish() as usize) & self.partition_mask
    }

    /// Appends one serialized row to the partition's spill file, creating the
    /// file lazily on first use.
    ///
    /// Record format: little-endian u32 payload length, then the payload.
    /// The spilled bytes (payload + 4-byte prefix) are reported to the
    /// reservation for accounting.
    fn spill_row_to_partition(&mut self, partition_idx: usize, row: &[SqlValue]) -> io::Result<()> {
        let partition = &mut self.partitions[partition_idx];

        if partition.spill_file.is_none() {
            let temp_dir = self.reservation.temp_directory();
            partition.spill_file =
                Some(SpillFile::with_suffix(temp_dir, &format!("agg_part_{}", partition_idx))?);
        }

        // Safe: just ensured the file exists above.
        let spill_file = partition.spill_file.as_mut().unwrap();
        let mut buf = Vec::new();
        let row_to_serialize = Row::from_vec(row.to_vec());
        serialize_row(&row_to_serialize, &mut buf)?;

        let len = buf.len() as u32;
        spill_file.write_all(&len.to_le_bytes())?;
        spill_file.write_all(&buf)?;

        partition.rows_spilled += 1;
        self.reservation.record_spill(buf.len() + 4);

        Ok(())
    }

    /// Switches the largest unspilled, non-empty partition to spill mode and
    /// releases its memory back to the reservation.
    ///
    /// # Errors
    /// Returns `OutOfMemory` when no partition is eligible to spill.
    fn spill_largest_partition(&mut self) -> io::Result<()> {
        let (largest_idx, largest_size) = self
            .partitions
            .iter()
            .enumerate()
            .filter(|(_, p)| !p.spilled && !p.groups.is_empty())
            .max_by_key(|(_, p)| p.memory_bytes)
            .map(|(i, p)| (i, p.memory_bytes))
            .ok_or_else(|| {
                io::Error::new(io::ErrorKind::OutOfMemory, "no partition available to spill")
            })?;

        let temp_dir = self.reservation.temp_directory().clone();
        let spill_file = SpillFile::with_suffix(&temp_dir, &format!("agg_part_{}", largest_idx))?;

        let partition = &mut self.partitions[largest_idx];
        // NOTE(review): the in-memory groups are cleared here without their
        // accumulated state being written anywhere, and `finish` only
        // re-aggregates rows spilled *after* this point — rows already folded
        // into these accumulators appear to be dropped from the final result.
        // Confirm whether partial aggregate state is meant to be serialized
        // here, or whether callers guarantee this situation never matters.
        let groups_lost = partition.groups.len();
        partition.groups.clear();

        self.reservation.shrink(largest_size);
        self.total_groups = self.total_groups.saturating_sub(groups_lost);
        partition.memory_bytes = 0;
        partition.spilled = true;
        partition.spill_file = Some(spill_file);

        Ok(())
    }

    /// Finalizes all in-memory groups and returns an iterator over result
    /// rows (`[key columns..., aggregate columns...]`), including groups
    /// reloaded from spill files.
    ///
    /// # Errors
    /// Returns any spill-file flush error, or `InvalidData` when an
    /// accumulator fails to finalize.
    pub fn finish(mut self) -> io::Result<AggregateResultIterator> {
        // Flush spill files so every record is readable before reloading.
        for partition in &mut self.partitions {
            if let Some(ref mut file) = partition.spill_file {
                file.flush()?;
            }
        }

        let mut in_memory_results: Vec<(Vec<SqlValue>, Vec<SqlValue>)> = Vec::new();

        for partition in &mut self.partitions {
            if !partition.spilled {
                for (_, (key_values, accumulators)) in partition.groups.drain() {
                    let agg_values: Vec<SqlValue> = accumulators
                        .iter()
                        .map(|a| {
                            a.finalize().map_err(|e| {
                                io::Error::new(io::ErrorKind::InvalidData, e.to_string())
                            })
                        })
                        .collect::<io::Result<Vec<_>>>()?;
                    in_memory_results.push((key_values, agg_values));
                }
            }
        }

        // Only spilled partitions that actually wrote rows need reloading.
        let spilled_partitions: Vec<_> = self
            .partitions
            .into_iter()
            .enumerate()
            .filter(|(_, p)| p.spilled && p.rows_spilled > 0)
            .collect();

        Ok(AggregateResultIterator {
            in_memory_results: in_memory_results.into_iter(),
            spilled_partitions,
            current_spill_idx: 0,
            aggregate_specs: self.aggregate_specs,
            num_key_columns: self.num_key_columns,
            partition_mask: self.partition_mask,
            #[allow(dead_code)]
            reservation: self.reservation,
        })
    }

    /// Number of groups currently held in memory (spilled groups excluded —
    /// see `spill_largest_partition`, which subtracts them).
    pub fn num_groups(&self) -> usize {
        self.total_groups
    }

    /// Number of partitions that have switched to spill mode.
    pub fn num_spilled_partitions(&self) -> usize {
        self.partitions.iter().filter(|p| p.spilled).count()
    }
}
423
/// Iterator over finished aggregation results.
///
/// Yields in-memory groups first, then lazily reloads and re-aggregates each
/// spilled partition on demand.
pub struct AggregateResultIterator {
    /// Finalized groups ready to emit: (key values, aggregate values).
    in_memory_results: std::vec::IntoIter<(Vec<SqlValue>, Vec<SqlValue>)>,

    /// Spilled partitions (with their original index) still to be reloaded.
    spilled_partitions: Vec<(usize, Partition)>,

    /// Index of the next entry in `spilled_partitions` to process.
    current_spill_idx: usize,

    /// Aggregate definitions, needed to rebuild accumulators on reload.
    aggregate_specs: Vec<AggregateSpec>,

    /// Number of leading key columns in each reloaded row.
    num_key_columns: usize,

    /// Carried over from the aggregate; not used during iteration.
    #[allow(dead_code)]
    partition_mask: usize,

    /// Held so the memory reservation outlives result iteration.
    #[allow(dead_code)]
    reservation: MemoryReservation,
}
449
450impl Iterator for AggregateResultIterator {
451 type Item = io::Result<Vec<SqlValue>>;
452
453 fn next(&mut self) -> Option<Self::Item> {
454 if let Some((key_values, agg_values)) = self.in_memory_results.next() {
456 let mut result = key_values;
457 result.extend(agg_values);
458 return Some(Ok(result));
459 }
460
461 while self.current_spill_idx < self.spilled_partitions.len() {
463 let result = self.process_next_spilled_partition();
464 if result.is_some() {
465 return result;
466 }
467 self.current_spill_idx += 1;
468 }
469
470 None
471 }
472}
473
474impl AggregateResultIterator {
475 fn process_next_spilled_partition(&mut self) -> Option<io::Result<Vec<SqlValue>>> {
477 if self.current_spill_idx >= self.spilled_partitions.len() {
478 return None;
479 }
480
481 let (_, partition) = &mut self.spilled_partitions[self.current_spill_idx];
483 let rows_spilled = partition.rows_spilled;
484
485 let spill_file_opt = partition.spill_file.take();
487 partition.rows_spilled = 0;
488
489 if let Some(mut spill_file) = spill_file_opt {
490 match self.reload_and_aggregate_partition(&mut spill_file, rows_spilled) {
492 Ok(mut results) => {
493 if let Some((key_values, agg_values)) = results.pop() {
495 if !results.is_empty() {
497 self.in_memory_results = results.into_iter();
498 }
499
500 let mut result = key_values;
501 result.extend(agg_values);
502 return Some(Ok(result));
503 }
504 }
505 Err(e) => return Some(Err(e)),
506 }
507 }
508
509 None
510 }
511
512 fn reload_and_aggregate_partition(
514 &self,
515 spill_file: &mut SpillFile,
516 num_rows: usize,
517 ) -> io::Result<Vec<(Vec<SqlValue>, Vec<SqlValue>)>> {
518 spill_file.prepare_for_read()?;
519
520 let mut groups: HashMap<Vec<u8>, (Vec<SqlValue>, Vec<AggregateAccumulator>)> =
522 HashMap::new();
523
524 for _ in 0..num_rows {
525 let mut len_buf = [0u8; 4];
527 spill_file.read_exact(&mut len_buf)?;
528 let len = u32::from_le_bytes(len_buf) as usize;
529
530 let mut row_buf = vec![0u8; len];
531 spill_file.read_exact(&mut row_buf)?;
532
533 let row = deserialize_row(&mut Cursor::new(row_buf))?;
534 let row_values: Vec<SqlValue> = row.values.into_iter().collect();
535
536 let key_values = &row_values[..self.num_key_columns];
538 let key_bytes = serialize_key(key_values);
539
540 if let Some((_, accumulators)) = groups.get_mut(&key_bytes) {
541 for (spec, acc) in self.aggregate_specs.iter().zip(accumulators.iter_mut()) {
542 let value = &row_values[self.num_key_columns + spec.value_index];
543 acc.accumulate(value);
544 }
545 } else {
546 let mut accumulators: Vec<AggregateAccumulator> = self
547 .aggregate_specs
548 .iter()
549 .map(|spec| {
550 AggregateAccumulator::new(&spec.function_name, spec.distinct)
551 .expect("valid spec")
552 })
553 .collect();
554
555 for (spec, acc) in self.aggregate_specs.iter().zip(accumulators.iter_mut()) {
556 let value = &row_values[self.num_key_columns + spec.value_index];
557 acc.accumulate(value);
558 }
559
560 groups.insert(key_bytes, (key_values.to_vec(), accumulators));
561 }
562 }
563
564 let mut results = Vec::with_capacity(groups.len());
566 for (key_values, accumulators) in groups.into_values() {
567 let agg_values: Vec<SqlValue> = accumulators
568 .iter()
569 .map(|a| {
570 a.finalize()
571 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))
572 })
573 .collect::<io::Result<Vec<_>>>()?;
574 results.push((key_values, agg_values));
575 }
576
577 Ok(results)
578 }
579}
580
581fn serialize_key(values: &[SqlValue]) -> Vec<u8> {
583 use super::row_serialization::serialize_value;
584 let mut buf = Vec::new();
585 for v in values {
586 serialize_value(v, &mut buf).expect("key serialization should not fail");
587 }
588 buf
589}
590
/// Feeds a single `SqlValue` into `hasher` for partition routing.
///
/// The enum discriminant is hashed first so equal payloads of different
/// variants (e.g. Character vs Varchar) do not collide structurally. Floats
/// are hashed by bit pattern (`to_bits`), which is deterministic even for
/// NaN. This hash only selects a partition; group identity itself comes from
/// `serialize_key` bytes, so hash collisions are harmless for correctness.
fn hash_sql_value<H: std::hash::Hasher>(value: &SqlValue, hasher: &mut H) {
    use std::hash::Hash;
    std::mem::discriminant(value).hash(hasher);
    match value {
        SqlValue::Null => {}
        SqlValue::Boolean(b) => b.hash(hasher),
        SqlValue::Smallint(i) => i.hash(hasher),
        SqlValue::Integer(i) => i.hash(hasher),
        SqlValue::Bigint(i) => i.hash(hasher),
        SqlValue::Unsigned(u) => u.hash(hasher),
        SqlValue::Float(f) => f.to_bits().hash(hasher),
        SqlValue::Real(f) => f.to_bits().hash(hasher),
        SqlValue::Double(f) => f.to_bits().hash(hasher),
        SqlValue::Numeric(f) => f.to_bits().hash(hasher),
        SqlValue::Character(s) | SqlValue::Varchar(s) => s.hash(hasher),
        SqlValue::Date(d) => (d.year, d.month, d.day).hash(hasher),
        SqlValue::Time(t) => (t.hour, t.minute, t.second, t.nanosecond).hash(hasher),
        SqlValue::Timestamp(ts) => {
            (ts.date.year, ts.date.month, ts.date.day).hash(hasher);
            (ts.time.hour, ts.time.minute, ts.time.second, ts.time.nanosecond).hash(hasher);
        }
        SqlValue::Interval(iv) => {
            // NOTE(review): only `iv.value` is hashed — if Interval also
            // carries a unit/kind field, intervals differing only in unit
            // would share a partition (benign for routing, but confirm).
            iv.value.hash(hasher);
        }
        SqlValue::Vector(v) => {
            for f in v {
                f.to_bits().hash(hasher);
            }
        }
        SqlValue::Blob(b) => b.hash(hasher),
    }
}
625
#[cfg(test)]
mod tests {
    use super::*;

    // 1 MiB budget: generous enough that these tests aggregate fully in
    // memory unless a test deliberately shrinks the budget.
    fn make_test_controller() -> Arc<MemoryController> {
        Arc::new(MemoryController::with_budget(1024 * 1024))
    }

    /// COUNT over two string-keyed groups, entirely in memory.
    #[test]
    fn test_simple_count() {
        let controller = make_test_controller();
        let specs = vec![AggregateSpec {
            function_name: "COUNT".to_string(),
            distinct: false,
            value_index: 0,
        }];

        let mut agg = ExternalAggregate::new(&controller, 1, specs);

        // Three rows in group "a"...
        agg.add_row(&[SqlValue::Varchar("a".into()), SqlValue::Integer(1)]).unwrap();
        agg.add_row(&[SqlValue::Varchar("a".into()), SqlValue::Integer(2)]).unwrap();
        agg.add_row(&[SqlValue::Varchar("a".into()), SqlValue::Integer(3)]).unwrap();

        // ...and two rows in group "b".
        agg.add_row(&[SqlValue::Varchar("b".into()), SqlValue::Integer(10)]).unwrap();
        agg.add_row(&[SqlValue::Varchar("b".into()), SqlValue::Integer(20)]).unwrap();

        let results: Vec<_> = agg.finish().unwrap().map(|r| r.unwrap()).collect();

        assert_eq!(results.len(), 2);

        // Output order is hash-dependent, so locate each group by its key.
        let group_a = results.iter().find(|r| r[0] == SqlValue::Varchar("a".into())).unwrap();
        let group_b = results.iter().find(|r| r[0] == SqlValue::Varchar("b".into())).unwrap();

        assert_eq!(group_a[1], SqlValue::Integer(3)); // COUNT for "a"
        assert_eq!(group_b[1], SqlValue::Integer(2)); // COUNT for "b"
    }

    /// SUM and AVG over the same input column for a single group.
    #[test]
    fn test_sum_and_avg() {
        let controller = make_test_controller();
        let specs = vec![
            AggregateSpec { function_name: "SUM".to_string(), distinct: false, value_index: 0 },
            AggregateSpec { function_name: "AVG".to_string(), distinct: false, value_index: 0 },
        ];

        let mut agg = ExternalAggregate::new(&controller, 1, specs);

        agg.add_row(&[SqlValue::Integer(1), SqlValue::Integer(10)]).unwrap();
        agg.add_row(&[SqlValue::Integer(1), SqlValue::Integer(20)]).unwrap();
        agg.add_row(&[SqlValue::Integer(1), SqlValue::Integer(30)]).unwrap();

        let results: Vec<_> = agg.finish().unwrap().map(|r| r.unwrap()).collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0][0], SqlValue::Integer(1)); // key
        assert_eq!(results[0][1], SqlValue::Integer(60)); // SUM
        assert_eq!(results[0][2], SqlValue::Double(20.0)); // AVG
    }

    /// MIN and MAX track the extremes regardless of arrival order.
    #[test]
    fn test_min_max() {
        let controller = make_test_controller();
        let specs = vec![
            AggregateSpec { function_name: "MIN".to_string(), distinct: false, value_index: 0 },
            AggregateSpec { function_name: "MAX".to_string(), distinct: false, value_index: 0 },
        ];

        let mut agg = ExternalAggregate::new(&controller, 1, specs);

        agg.add_row(&[SqlValue::Varchar("x".into()), SqlValue::Integer(5)]).unwrap();
        agg.add_row(&[SqlValue::Varchar("x".into()), SqlValue::Integer(15)]).unwrap();
        agg.add_row(&[SqlValue::Varchar("x".into()), SqlValue::Integer(10)]).unwrap();

        let results: Vec<_> = agg.finish().unwrap().map(|r| r.unwrap()).collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0][1], SqlValue::Integer(5)); // MIN
        assert_eq!(results[0][2], SqlValue::Integer(15)); // MAX
    }

    /// Two key columns: one group per distinct (Integer, Varchar) pair.
    #[test]
    fn test_multi_key_grouping() {
        let controller = make_test_controller();
        let specs = vec![AggregateSpec {
            function_name: "COUNT".to_string(),
            distinct: false,
            value_index: 0,
        }];

        // num_key_columns = 2: the first two values of each row are the key.
        let mut agg = ExternalAggregate::new(&controller, 2, specs);
        agg.add_row(&[SqlValue::Integer(1), SqlValue::Varchar("a".into()), SqlValue::Integer(100)])
            .unwrap();
        agg.add_row(&[SqlValue::Integer(1), SqlValue::Varchar("a".into()), SqlValue::Integer(200)])
            .unwrap();

        agg.add_row(&[SqlValue::Integer(1), SqlValue::Varchar("b".into()), SqlValue::Integer(300)])
            .unwrap();

        agg.add_row(&[SqlValue::Integer(2), SqlValue::Varchar("a".into()), SqlValue::Integer(400)])
            .unwrap();
        agg.add_row(&[SqlValue::Integer(2), SqlValue::Varchar("a".into()), SqlValue::Integer(500)])
            .unwrap();
        agg.add_row(&[SqlValue::Integer(2), SqlValue::Varchar("a".into()), SqlValue::Integer(600)])
            .unwrap();

        let results: Vec<_> = agg.finish().unwrap().map(|r| r.unwrap()).collect();

        assert_eq!(results.len(), 3);

        let g1a = results
            .iter()
            .find(|r| r[0] == SqlValue::Integer(1) && r[1] == SqlValue::Varchar("a".into()))
            .unwrap();
        let g1b = results
            .iter()
            .find(|r| r[0] == SqlValue::Integer(1) && r[1] == SqlValue::Varchar("b".into()))
            .unwrap();
        let g2a = results
            .iter()
            .find(|r| r[0] == SqlValue::Integer(2) && r[1] == SqlValue::Varchar("a".into()))
            .unwrap();

        assert_eq!(g1a[2], SqlValue::Integer(2)); // (1, "a") saw 2 rows
        assert_eq!(g1b[2], SqlValue::Integer(1)); // (1, "b") saw 1 row
        assert_eq!(g2a[2], SqlValue::Integer(3)); // (2, "a") saw 3 rows
    }

    /// NULL inputs are skipped by COUNT(col) and SUM per the assertions below.
    #[test]
    fn test_null_handling() {
        let controller = make_test_controller();
        let specs = vec![
            AggregateSpec { function_name: "COUNT".to_string(), distinct: false, value_index: 0 },
            AggregateSpec { function_name: "SUM".to_string(), distinct: false, value_index: 0 },
        ];

        let mut agg = ExternalAggregate::new(&controller, 1, specs);

        agg.add_row(&[SqlValue::Integer(1), SqlValue::Integer(10)]).unwrap();
        agg.add_row(&[SqlValue::Integer(1), SqlValue::Null]).unwrap();
        agg.add_row(&[SqlValue::Integer(1), SqlValue::Integer(20)]).unwrap();
        agg.add_row(&[SqlValue::Integer(1), SqlValue::Null]).unwrap();

        let results: Vec<_> = agg.finish().unwrap().map(|r| r.unwrap()).collect();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0][1], SqlValue::Integer(2)); // COUNT ignores NULLs
        assert_eq!(results[0][2], SqlValue::Integer(30)); // SUM ignores NULLs
    }

    /// No input rows produce no output groups.
    #[test]
    fn test_empty_input() {
        let controller = make_test_controller();
        let specs = vec![AggregateSpec {
            function_name: "COUNT".to_string(),
            distinct: false,
            value_index: 0,
        }];

        let agg = ExternalAggregate::new(&controller, 1, specs);
        let results: Vec<_> = agg.finish().unwrap().map(|r| r.unwrap()).collect();

        assert!(results.is_empty());
    }

    /// Forces spilling by combining a tiny (4 KiB) budget with 100 distinct
    /// groups spread over only 4 partitions.
    #[test]
    fn test_spill_under_memory_pressure() {
        let controller = Arc::new(MemoryController::with_budget(4096));
        let config = ExternalAggregateConfig {
            num_partitions: 4,
            max_groups_per_partition: 10,
        };

        let specs = vec![AggregateSpec {
            function_name: "SUM".to_string(),
            distinct: false,
            value_index: 0,
        }];

        let mut agg = ExternalAggregate::with_config(&controller, 1, specs, config);

        for i in 0..100 {
            agg.add_row(&[SqlValue::Integer(i), SqlValue::Integer(i * 10)]).unwrap();
        }

        let _ = agg.num_spilled_partitions();

        let results: Vec<_> = agg.finish().unwrap().map(|r| r.unwrap()).collect();

        // NOTE(review): only an upper bound is asserted. Groups that were
        // already aggregated in memory when their partition spilled are
        // cleared by `spill_largest_partition` without being persisted, so
        // an exact `== 100` assertion would likely fail — confirm whether
        // that loss is intended before tightening this check.
        assert!(results.len() <= 100);
    }
}