1use std::{
41 cmp::Ordering,
42 collections::BinaryHeap,
43 io::{self, Cursor},
44 sync::Arc,
45};
46
47use vibesql_ast::OrderDirection;
48use vibesql_storage::Row;
49use vibesql_types::SqlValue;
50
51use super::{
52 row_serialization::{deserialize_row_with_keys, serialize_row_with_keys},
53 MemoryController, MemoryReservation, SpillFile,
54};
55
/// Sort key for one row: an ordered list of key values, each paired with the
/// direction (`Asc`/`Desc`) that applies to that key position.
pub type SortKey = Vec<(SqlValue, OrderDirection)>;
58
/// A row bundled with the sort key it is ordered by.
pub type RowWithKeys = (Row, SortKey);
61
/// Tuning knobs for the external sort.
#[derive(Debug, Clone)]
pub struct ExternalSortConfig {
    /// Number of buffered rows at which the in-memory buffer is sorted and
    /// written out as a run.
    pub max_run_size: usize,

    /// NOTE(review): not read anywhere in the visible code; presumably the
    /// intended maximum number of runs merged at once — confirm before relying
    /// on it.
    pub merge_fanout: usize,
}

impl Default for ExternalSortConfig {
    /// Returns the stock configuration: 50,000 rows per run and a fan-out of 16.
    fn default() -> Self {
        const DEFAULT_MAX_RUN_SIZE: usize = 50_000;
        const DEFAULT_MERGE_FANOUT: usize = 16;

        ExternalSortConfig {
            max_run_size: DEFAULT_MAX_RUN_SIZE,
            merge_fanout: DEFAULT_MERGE_FANOUT,
        }
    }
}
80
/// One sorted run of rows waiting to be merged.
enum SortRun {
    /// A run whose rows are held in memory. Nothing in the visible code
    /// constructs this variant, hence the `dead_code` allowance.
    #[allow(dead_code)]
    InMemory(Vec<RowWithKeys>),

    /// A run serialized into a spill file, together with the number of rows
    /// that were written to it.
    OnDisk { file: SpillFile, row_count: usize },
}
90
91impl SortRun {
92 fn row_count(&self) -> usize {
94 match self {
95 SortRun::InMemory(rows) => rows.len(),
96 SortRun::OnDisk { row_count, .. } => *row_count,
97 }
98 }
99}
100
/// Sorter that buffers rows in memory and writes sorted runs to spill files
/// when the memory reservation cannot grow or the buffer reaches the
/// configured run size.
pub struct ExternalSort {
    /// Memory reservation grown for every buffered row and shrunk when the
    /// buffer is spilled; also supplies the temp directory for spill files.
    reservation: MemoryReservation,

    /// Run-size / fan-out configuration.
    config: ExternalSortConfig,

    /// Rows accepted since the last spill, not yet sorted.
    buffer: Vec<RowWithKeys>,

    /// Sum of the per-row size estimates currently charged to `reservation`
    /// on behalf of `buffer`.
    buffer_memory: usize,

    /// Sorted runs produced so far, in creation order.
    runs: Vec<SortRun>,

    /// Comparator used both for sorting the buffer and for merging runs.
    comparator: SortKeyComparator,
}
123
/// Stateless comparator over [`SortKey`]s; it is a unit struct, so cloning it
/// copies no data.
#[derive(Clone)]
struct SortKeyComparator;
127
128impl SortKeyComparator {
129 fn compare(&self, a: &SortKey, b: &SortKey) -> Ordering {
131 for ((val_a, dir), (val_b, _)) in a.iter().zip(b.iter()) {
132 let cmp = match (val_a.is_null(), val_b.is_null()) {
134 (true, true) => Ordering::Equal,
135 (true, false) => return Ordering::Greater, (false, true) => return Ordering::Less, (false, false) => {
138 let base_cmp = compare_sql_values(val_a, val_b);
140 match dir {
141 OrderDirection::Asc => base_cmp,
142 OrderDirection::Desc => base_cmp.reverse(),
143 }
144 }
145 };
146
147 if cmp != Ordering::Equal {
148 return cmp;
149 }
150 }
151 Ordering::Equal
152 }
153
154 fn compare_with_rowid(
160 &self,
161 a: &RowWithKeys,
162 b: &RowWithKeys,
163 ) -> Ordering {
164 let key_cmp = self.compare(&a.1, &b.1);
165 if key_cmp != Ordering::Equal {
166 return key_cmp;
167 }
168 a.0.row_id.cmp(&b.0.row_id)
170 }
171}
172
173fn compare_sql_values(a: &SqlValue, b: &SqlValue) -> Ordering {
175 use SqlValue::*;
176 match (a, b) {
177 (Integer(x), Integer(y)) => x.cmp(y),
178 (Smallint(x), Smallint(y)) => x.cmp(y),
179 (Bigint(x), Bigint(y)) => x.cmp(y),
180 (Unsigned(x), Unsigned(y)) => x.cmp(y),
181 (Float(x), Float(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
182 (Real(x), Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
183 (Double(x), Double(y)) | (Numeric(x), Numeric(y)) => {
184 x.partial_cmp(y).unwrap_or(Ordering::Equal)
185 }
186 (Character(x), Character(y)) | (Varchar(x), Varchar(y)) => x.cmp(y),
187 (Character(x), Varchar(y)) | (Varchar(x), Character(y)) => x.as_str().cmp(y.as_str()),
188 (Boolean(x), Boolean(y)) => x.cmp(y),
189 (Date(x), Date(y)) => x.cmp(y),
190 (Time(x), Time(y)) => x.cmp(y),
191 (Timestamp(x), Timestamp(y)) => x.cmp(y),
192 (Interval(x), Interval(y)) => x.cmp(y),
193 _ => Ordering::Equal, }
195}
196
197impl ExternalSort {
198 pub fn new(controller: &Arc<MemoryController>) -> Self {
200 Self::with_config(controller, ExternalSortConfig::default())
201 }
202
203 pub fn with_config(controller: &Arc<MemoryController>, config: ExternalSortConfig) -> Self {
205 Self {
206 reservation: controller.create_reservation(),
207 config,
208 buffer: Vec::new(),
209 buffer_memory: 0,
210 runs: Vec::new(),
211 comparator: SortKeyComparator,
212 }
213 }
214
215 pub fn add_row(&mut self, row: Row, sort_keys: SortKey) -> io::Result<()> {
219 let row_memory = row.estimated_size_bytes()
221 + sort_keys.iter().map(|(v, _)| v.estimated_size_bytes()).sum::<usize>()
222 + std::mem::size_of::<RowWithKeys>();
223
224 if !self.reservation.try_grow(row_memory) {
226 self.spill_buffer()?;
228
229 if !self.reservation.try_grow(row_memory) {
231 return Err(io::Error::new(
234 io::ErrorKind::OutOfMemory,
235 "single row exceeds available memory budget",
236 ));
237 }
238 }
239
240 self.buffer.push((row, sort_keys));
241 self.buffer_memory += row_memory;
242
243 if self.buffer.len() >= self.config.max_run_size {
245 self.spill_buffer()?;
246 }
247
248 Ok(())
249 }
250
251 fn spill_buffer(&mut self) -> io::Result<()> {
253 if self.buffer.is_empty() {
254 return Ok(());
255 }
256
257 let comparator = self.comparator.clone();
259 self.buffer.sort_by(|a, b| comparator.compare_with_rowid(a, b));
260
261 let temp_dir = self.reservation.temp_directory();
263 let mut spill_file = SpillFile::with_suffix(temp_dir, "sort_run")?;
264
265 let row_count = self.buffer.len();
267 let mut spill_buf = Vec::new();
268
269 for (row, keys) in &self.buffer {
270 serialize_row_with_keys(row, keys, &mut spill_buf)?;
271 }
272
273 spill_file.write_all(&spill_buf)?;
274 spill_file.flush()?;
275
276 self.reservation.record_spill(spill_buf.len());
278
279 self.reservation.shrink(self.buffer_memory);
281 self.buffer.clear();
282 self.buffer_memory = 0;
283
284 self.runs.push(SortRun::OnDisk { file: spill_file, row_count });
286
287 Ok(())
288 }
289
290 pub fn finish(mut self) -> io::Result<SortedIterator> {
294 if !self.buffer.is_empty() {
296 let comparator = self.comparator.clone();
297 self.buffer.sort_by(|a, b| comparator.compare_with_rowid(a, b));
298
299 if self.runs.is_empty() {
301 return Ok(SortedIterator::InMemory(InMemoryIterator::new(
303 self.buffer,
304 self.reservation,
305 )));
306 }
307
308 self.spill_buffer()?;
310 }
311
312 if self.runs.is_empty() {
314 return Ok(SortedIterator::InMemory(InMemoryIterator::new(
315 Vec::new(),
316 self.reservation,
317 )));
318 }
319
320 Ok(SortedIterator::Merge(MergeIterator::new(self.runs, self.comparator, self.reservation)?))
322 }
323
324 pub fn run_count(&self) -> usize {
326 self.runs.len() + if self.buffer.is_empty() { 0 } else { 1 }
327 }
328
329 pub fn total_rows(&self) -> usize {
331 self.runs.iter().map(|r| r.row_count()).sum::<usize>() + self.buffer.len()
332 }
333}
334
/// Iterator over sorted rows, returned by `ExternalSort::finish`.
pub enum SortedIterator {
    /// All rows were sorted in memory; nothing was spilled.
    InMemory(InMemoryIterator),

    /// Rows are produced by merging spilled runs.
    Merge(MergeIterator),
}
343
344impl Iterator for SortedIterator {
345 type Item = io::Result<Row>;
346
347 fn next(&mut self) -> Option<Self::Item> {
348 match self {
349 SortedIterator::InMemory(iter) => iter.next(),
350 SortedIterator::Merge(iter) => iter.next(),
351 }
352 }
353}
354
/// Iterator over rows that were sorted entirely in memory.
pub struct InMemoryIterator {
    /// Already-sorted rows (with their keys) being drained.
    rows: std::vec::IntoIter<RowWithKeys>,
    // Never read in the visible code; presumably held so the reservation stays
    // alive as long as the rows do — confirm against `MemoryReservation`.
    #[allow(dead_code)]
    reservation: MemoryReservation,
}
361
362impl InMemoryIterator {
363 fn new(rows: Vec<RowWithKeys>, reservation: MemoryReservation) -> Self {
364 Self { rows: rows.into_iter(), reservation }
365 }
366}
367
368impl Iterator for InMemoryIterator {
369 type Item = io::Result<Row>;
370
371 fn next(&mut self) -> Option<Self::Item> {
372 self.rows.next().map(|(row, _keys)| Ok(row))
373 }
374}
375
/// Iterator that merges several sorted runs into one sorted stream.
pub struct MergeIterator {
    /// Pending candidate rows; the entry that sorts first is popped next.
    heap: BinaryHeap<MergeEntry>,

    /// One reader per run, indexed by the run's position (`run_index`).
    readers: Vec<RunReader>,

    /// Comparator cloned into every entry pushed onto the heap.
    comparator: SortKeyComparator,

    // Never read in the visible code; presumably held so the reservation stays
    // alive for the duration of the merge — confirm against `MemoryReservation`.
    #[allow(dead_code)]
    reservation: MemoryReservation,
}
391
392struct MergeEntry {
394 row: Row,
396
397 key: SortKey,
399
400 run_index: usize,
402
403 comparator: SortKeyComparator,
405}
406
407impl PartialEq for MergeEntry {
408 fn eq(&self, other: &Self) -> bool {
409 let key_cmp = self.comparator.compare(&self.key, &other.key);
411 if key_cmp != Ordering::Equal {
412 return false;
413 }
414 self.row.row_id == other.row.row_id
415 }
416}
417
418impl Eq for MergeEntry {}
419
420impl PartialOrd for MergeEntry {
421 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
422 Some(self.cmp(other))
423 }
424}
425
426impl Ord for MergeEntry {
427 fn cmp(&self, other: &Self) -> Ordering {
428 let key_cmp = self.comparator.compare(&other.key, &self.key);
431 if key_cmp != Ordering::Equal {
432 return key_cmp;
433 }
434 other.row.row_id.cmp(&self.row.row_id)
437 }
438}
439
/// Sequential reader over the serialized rows of one run.
struct RunReader {
    /// The run's bytes, held in memory, with the current read position.
    cursor: Cursor<Vec<u8>>,

    /// Number of rows not yet read from `cursor`.
    remaining: usize,
}
448
449impl RunReader {
450 fn new(mut file: SpillFile, row_count: usize) -> io::Result<Self> {
451 file.prepare_for_read()?;
452 let data = file.read_to_vec()?;
453 Ok(Self { cursor: Cursor::new(data), remaining: row_count })
454 }
455
456 fn read_next(&mut self) -> io::Result<Option<RowWithKeys>> {
457 if self.remaining == 0 {
458 return Ok(None);
459 }
460 self.remaining -= 1;
461 let (row, keys) = deserialize_row_with_keys(&mut self.cursor)?;
462 Ok(Some((row, keys)))
463 }
464}
465
466impl MergeIterator {
467 fn new(
468 runs: Vec<SortRun>,
469 comparator: SortKeyComparator,
470 reservation: MemoryReservation,
471 ) -> io::Result<Self> {
472 let mut readers = Vec::with_capacity(runs.len());
473 let mut heap = BinaryHeap::new();
474
475 for (run_index, run) in runs.into_iter().enumerate() {
476 match run {
477 SortRun::InMemory(rows) => {
478 for (row, key) in rows {
480 heap.push(MergeEntry {
481 row,
482 key,
483 run_index,
484 comparator: comparator.clone(),
485 });
486 }
487 readers.push(RunReader { cursor: Cursor::new(Vec::new()), remaining: 0 });
488 }
489 SortRun::OnDisk { file, row_count } => {
490 let mut reader = RunReader::new(file, row_count)?;
491
492 if let Some((row, key)) = reader.read_next()? {
494 heap.push(MergeEntry {
495 row,
496 key,
497 run_index,
498 comparator: comparator.clone(),
499 });
500 }
501
502 readers.push(reader);
503 }
504 }
505 }
506
507 Ok(Self { heap, readers, comparator, reservation })
508 }
509}
510
511impl Iterator for MergeIterator {
512 type Item = io::Result<Row>;
513
514 fn next(&mut self) -> Option<Self::Item> {
515 let entry = self.heap.pop()?;
517
518 match self.readers[entry.run_index].read_next() {
520 Ok(Some((row, key))) => {
521 self.heap.push(MergeEntry {
522 row,
523 key,
524 run_index: entry.run_index,
525 comparator: self.comparator.clone(),
526 });
527 }
528 Ok(None) => {
529 }
531 Err(e) => {
532 return Some(Err(e));
533 }
534 }
535
536 Some(Ok(entry.row))
537 }
538}
539
#[cfg(test)]
mod tests {
    use super::*;

    /// Controller with a 1 MiB budget, shared by most tests.
    fn make_test_controller() -> Arc<MemoryController> {
        let controller = MemoryController::with_budget(1024 * 1024);
        Arc::new(controller)
    }

    /// Adds a single-column row whose only sort key is `key_value`.
    fn add_single(
        sorter: &mut ExternalSort,
        row_value: SqlValue,
        key_value: SqlValue,
        direction: OrderDirection,
    ) {
        let row = Row::from_vec(vec![row_value]);
        let keys = vec![(key_value, direction)];
        sorter.add_row(row, keys).unwrap();
    }

    /// Adds a two-column integer row sorted ascending on both columns.
    fn add_pair(sorter: &mut ExternalSort, first: i64, second: i64) {
        let row = Row::from_vec(vec![SqlValue::Integer(first), SqlValue::Integer(second)]);
        let keys = vec![
            (SqlValue::Integer(first), OrderDirection::Asc),
            (SqlValue::Integer(second), OrderDirection::Asc),
        ];
        sorter.add_row(row, keys).unwrap();
    }

    /// Finishes the sort and collects every row, panicking on any error.
    fn sorted_rows(sorter: ExternalSort) -> Vec<Row> {
        sorter.finish().unwrap().map(|r| r.unwrap()).collect()
    }

    #[test]
    fn test_in_memory_sort() {
        let controller = make_test_controller();
        let mut sorter = ExternalSort::new(&controller);

        // Feed the values in reverse so the sort has real work to do.
        for value in (0..100).rev() {
            add_single(
                &mut sorter,
                SqlValue::Integer(value),
                SqlValue::Integer(value),
                OrderDirection::Asc,
            );
        }

        let results = sorted_rows(sorter);

        assert_eq!(results.len(), 100);
        for (row, expected) in results.iter().zip(0i64..) {
            assert_eq!(row.values[0], SqlValue::Integer(expected));
        }
    }

    #[test]
    fn test_descending_sort() {
        let controller = make_test_controller();
        let mut sorter = ExternalSort::new(&controller);

        for value in 0..50 {
            add_single(
                &mut sorter,
                SqlValue::Integer(value),
                SqlValue::Integer(value),
                OrderDirection::Desc,
            );
        }

        let results = sorted_rows(sorter);

        assert_eq!(results.len(), 50);
        // Descending order: 49, 48, ..., 0.
        for (row, expected) in results.iter().zip((0..50i64).rev()) {
            assert_eq!(row.values[0], SqlValue::Integer(expected));
        }
    }

    #[test]
    fn test_null_handling() {
        let controller = make_test_controller();
        let mut sorter = ExternalSort::new(&controller);

        add_single(&mut sorter, SqlValue::Integer(2), SqlValue::Integer(2), OrderDirection::Asc);
        add_single(&mut sorter, SqlValue::Null, SqlValue::Null, OrderDirection::Asc);
        add_single(&mut sorter, SqlValue::Integer(1), SqlValue::Integer(1), OrderDirection::Asc);

        let results = sorted_rows(sorter);

        // The NULL row must come after both non-NULL rows.
        assert_eq!(results[0].values[0], SqlValue::Integer(1));
        assert_eq!(results[1].values[0], SqlValue::Integer(2));
        assert_eq!(results[2].values[0], SqlValue::Null);
    }

    #[test]
    fn test_multi_key_sort() {
        let controller = make_test_controller();
        let mut sorter = ExternalSort::new(&controller);

        add_pair(&mut sorter, 1, 2);
        add_pair(&mut sorter, 1, 1);
        add_pair(&mut sorter, 2, 1);

        let results = sorted_rows(sorter);

        // Ordered by the first column, then by the second.
        let expected: [(i64, i64); 3] = [(1, 1), (1, 2), (2, 1)];
        for (index, (first, second)) in expected.iter().enumerate() {
            assert_eq!(results[index].values[0], SqlValue::Integer(*first));
            assert_eq!(results[index].values[1], SqlValue::Integer(*second));
        }
    }

    #[test]
    fn test_spill_and_merge() {
        let controller = Arc::new(MemoryController::with_budget(4096));
        // Ten rows per run forces several runs for fifty rows.
        let config = ExternalSortConfig { max_run_size: 10, ..ExternalSortConfig::default() };

        let mut sorter = ExternalSort::with_config(&controller, config);

        for value in (0..50).rev() {
            add_single(
                &mut sorter,
                SqlValue::Integer(value),
                SqlValue::Integer(value),
                OrderDirection::Asc,
            );
        }

        assert!(sorter.run_count() > 1, "Expected multiple runs from spilling");

        let results = sorted_rows(sorter);

        assert_eq!(results.len(), 50);
        for (row, expected) in results.iter().zip(0i64..) {
            assert_eq!(row.values[0], SqlValue::Integer(expected));
        }
    }

    #[test]
    fn test_empty_sort() {
        let controller = make_test_controller();
        let sorter = ExternalSort::new(&controller);

        let results = sorted_rows(sorter);
        assert!(results.is_empty());
    }

    #[test]
    fn test_string_sort() {
        let controller = make_test_controller();
        let mut sorter = ExternalSort::new(&controller);

        add_single(
            &mut sorter,
            SqlValue::Varchar("charlie".into()),
            SqlValue::Varchar("charlie".into()),
            OrderDirection::Asc,
        );
        add_single(
            &mut sorter,
            SqlValue::Varchar("alpha".into()),
            SqlValue::Varchar("alpha".into()),
            OrderDirection::Asc,
        );
        add_single(
            &mut sorter,
            SqlValue::Varchar("bravo".into()),
            SqlValue::Varchar("bravo".into()),
            OrderDirection::Asc,
        );

        let results = sorted_rows(sorter);

        assert_eq!(results[0].values[0], SqlValue::Varchar("alpha".into()));
        assert_eq!(results[1].values[0], SqlValue::Varchar("bravo".into()));
        assert_eq!(results[2].values[0], SqlValue::Varchar("charlie".into()));
    }
}