Skip to main content

sparrowdb_execution/
sort_spill.rs

1//! Spill-to-disk sort for ORDER BY on large result sets.
2//!
3//! Implements a k-way external merge sort:
4//! - Rows are buffered in memory up to `row_threshold` rows OR `byte_threshold` bytes.
5//! - When either threshold is exceeded, the in-memory buffer is sorted and written to
6//!   a `NamedTempFile` as a sorted run.
7//! - `finish()` merges all sorted runs (plus any remaining in-memory rows) using a
8//!   binary-heap-based k-way merge, returning a single sorted iterator.
9//!
10//! Row type `T` must implement `serde::Serialize + serde::de::DeserializeOwned + Ord`.
11//!
12//! SPA-113
13
14use std::cmp::Reverse;
15use std::collections::BinaryHeap;
16use std::io::{BufReader, BufWriter, Read, Write};
17
18use serde::{de::DeserializeOwned, Deserialize, Serialize};
19use sparrowdb_common::{Error, Result};
20use tempfile::NamedTempFile;
21
/// Default in-memory row threshold before spilling: once this many rows are
/// buffered, they are sorted and written out as a run.
pub const DEFAULT_ROW_THRESHOLD: usize = 100_000;

/// Default in-memory byte threshold (64 MiB) before spilling; compared
/// against the sorter's rough `byte_estimate`, not exact usage.
pub const DEFAULT_BYTE_THRESHOLD: usize = 64 * 1024 * 1024;
27
28// ---------------------------------------------------------------------------
29// SpillingSorter
30// ---------------------------------------------------------------------------
31
/// A sort operator that buffers rows in memory and spills sorted runs to disk
/// when either `row_threshold` or `byte_threshold` is exceeded.
pub struct SpillingSorter<T> {
    /// In-memory row buffer; sorted and flushed to a run file on spill.
    buffer: Vec<T>,
    /// Sorted run temp files (each file holds a contiguous sorted sequence
    /// of length-prefixed bincode frames); deleted automatically on drop.
    runs: Vec<NamedTempFile>,
    /// Maximum number of rows to hold in memory before spilling.
    row_threshold: usize,
    /// Maximum estimated in-memory bytes before spilling.
    byte_threshold: usize,
    /// Rough estimate of current in-memory bytes (rows pushed so far times
    /// `bytes_per_row`); reset to 0 after each spill.
    byte_estimate: usize,
    /// Bytes per row estimate (seeded at 64; refined from the serialized
    /// size of the first buffered row each time a run is spilled).
    bytes_per_row: usize,
}
48
49impl<T> SpillingSorter<T>
50where
51    T: Serialize + DeserializeOwned + Ord + Clone,
52{
53    /// Create a new `SpillingSorter` with default thresholds.
54    pub fn new() -> Self {
55        SpillingSorter::with_thresholds(DEFAULT_ROW_THRESHOLD, DEFAULT_BYTE_THRESHOLD)
56    }
57
58    /// Create with explicit thresholds (useful for testing spill behaviour with
59    /// a small threshold).
60    pub fn with_thresholds(row_threshold: usize, byte_threshold: usize) -> Self {
61        SpillingSorter {
62            buffer: Vec::new(),
63            runs: Vec::new(),
64            row_threshold,
65            byte_threshold,
66            byte_estimate: 0,
67            bytes_per_row: 64, // initial guess
68        }
69    }
70
71    /// Push a single row.  Spills the in-memory buffer if a threshold is
72    /// exceeded after the push.
73    pub fn push(&mut self, row: T) -> Result<()> {
74        self.byte_estimate += self.bytes_per_row;
75        self.buffer.push(row);
76
77        if self.buffer.len() >= self.row_threshold || self.byte_estimate >= self.byte_threshold {
78            self.spill()?;
79        }
80        Ok(())
81    }
82
83    /// Sort and merge all data, returning a sorted iterator over every row
84    /// that was pushed.
85    pub fn finish(mut self) -> Result<impl Iterator<Item = T>> {
86        if self.runs.is_empty() {
87            // No spill happened — sort in memory and return a plain iterator.
88            self.buffer.sort();
89            return Ok(SortedOutput::Memory(self.buffer.into_iter()));
90        }
91
92        // Spill any remaining in-memory rows as a final sorted run.
93        if !self.buffer.is_empty() {
94            self.spill()?;
95        }
96
97        // K-way merge using a min-heap.
98        let mut readers: Vec<RunReader<T>> = self
99            .runs
100            .into_iter()
101            .map(RunReader::new)
102            .collect::<Result<Vec<_>>>()?;
103
104        // Seed the heap.
105        let mut heap: BinaryHeap<HeapEntry<T>> = BinaryHeap::new();
106        for (idx, reader) in readers.iter_mut().enumerate() {
107            if let Some(row) = reader.next_row()? {
108                heap.push(HeapEntry {
109                    row: Reverse(row),
110                    run_idx: idx,
111                });
112            }
113        }
114
115        Ok(SortedOutput::Merge(MergeIter {
116            heap,
117            readers,
118            exhausted: false,
119        }))
120    }
121
122    // ── Private helpers ───────────────────────────────────────────────────
123
124    /// Sort the in-memory buffer and write it to a new temp file as a run.
125    fn spill(&mut self) -> Result<()> {
126        self.buffer.sort();
127
128        // Refine the bytes-per-row estimate from actual serialized size.
129        // We serialize a sample (the first row) to get a real estimate.
130        if let Some(first) = self.buffer.first() {
131            if let Ok(encoded) = bincode::serialize(first) {
132                // length-prefix (8 bytes varint-style) + payload
133                self.bytes_per_row = encoded.len() + 8;
134            }
135        }
136
137        let mut tmp = NamedTempFile::new().map_err(Error::Io)?;
138        {
139            let mut writer = BufWriter::new(tmp.as_file_mut());
140            for row in &self.buffer {
141                write_row(&mut writer, row)?;
142            }
143            writer.flush().map_err(Error::Io)?;
144        }
145
146        self.runs.push(tmp);
147        self.buffer.clear();
148        self.byte_estimate = 0;
149        Ok(())
150    }
151}
152
/// `Default` mirrors [`SpillingSorter::new`], i.e. the default thresholds.
impl<T> Default for SpillingSorter<T>
where
    T: Serialize + DeserializeOwned + Ord + Clone,
{
    fn default() -> Self {
        Self::new()
    }
}
161
162// ---------------------------------------------------------------------------
163// Row serialisation helpers
164// ---------------------------------------------------------------------------
165
166/// Write a length-prefixed bincode frame.
167fn write_row<W: Write, T: Serialize>(writer: &mut W, row: &T) -> Result<()> {
168    let encoded = bincode::serialize(row)
169        .map_err(|e| Error::InvalidArgument(format!("bincode encode: {e}")))?;
170    let len = encoded.len() as u64;
171    writer.write_all(&len.to_le_bytes()).map_err(Error::Io)?;
172    writer.write_all(&encoded).map_err(Error::Io)?;
173    Ok(())
174}
175
176/// Read the next length-prefixed bincode frame, returning `None` on EOF.
177fn read_row<R: Read, T: DeserializeOwned>(reader: &mut R) -> Result<Option<T>> {
178    let mut len_buf = [0u8; 8];
179    match reader.read_exact(&mut len_buf) {
180        Ok(()) => {}
181        Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
182        Err(e) => return Err(Error::Io(e)),
183    }
184    let len = u64::from_le_bytes(len_buf) as usize;
185    let mut data = vec![0u8; len];
186    reader.read_exact(&mut data).map_err(Error::Io)?;
187    let row: T = bincode::deserialize(&data)
188        .map_err(|e| Error::Corruption(format!("bincode decode: {e}")))?;
189    Ok(Some(row))
190}
191
192// ---------------------------------------------------------------------------
193// RunReader — reads rows from a single sorted run file
194// ---------------------------------------------------------------------------
195
/// Streams rows back from a single sorted run file.
struct RunReader<T> {
    /// Owns the temp file; kept alive so the file is auto-deleted on drop.
    _tmpfile: NamedTempFile,
    /// Buffered reader over a reopened handle to the same file.
    reader: BufReader<std::fs::File>,
    /// Ties the row type `T` to this reader without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}
201
202impl<T: DeserializeOwned> RunReader<T> {
203    fn new(tmp: NamedTempFile) -> Result<Self> {
204        // Reopen a second file descriptor for reading; the original
205        // NamedTempFile stays alive in `_tmpfile` and deletes the file on drop.
206        let read_handle = tmp.reopen().map_err(Error::Io)?;
207        Ok(RunReader {
208            _tmpfile: tmp,
209            reader: BufReader::new(read_handle),
210            _marker: std::marker::PhantomData,
211        })
212    }
213
214    fn next_row(&mut self) -> Result<Option<T>> {
215        read_row(&mut self.reader)
216    }
217}
218
219// ---------------------------------------------------------------------------
220// HeapEntry — wrapper for the k-way merge min-heap
221// ---------------------------------------------------------------------------
222
/// Heap element for the k-way merge: a run's current head row plus the index
/// of the run it came from.
struct HeapEntry<T: Ord> {
    /// `Reverse` turns std's max-heap into a min-heap on `T`.
    row: Reverse<T>,
    /// Index into `MergeIter::readers` identifying the source run.
    run_idx: usize,
}

impl<T: Ord> PartialEq for HeapEntry<T> {
    fn eq(&self, other: &Self) -> bool {
        // Must stay consistent with `Ord::cmp` below, which also
        // tie-breaks on `run_idx`.
        self.row == other.row && self.run_idx == other.run_idx
    }
}
impl<T: Ord> Eq for HeapEntry<T> {}
impl<T: Ord> PartialOrd for HeapEntry<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl<T: Ord> Ord for HeapEntry<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Primary: `Reverse` makes the max-heap pop the smallest row first.
        // Secondary: on equal rows, pop the lower run index first (runs are
        // created in arrival order), so the merge is deterministic and
        // stable across runs instead of depending on unspecified heap
        // ordering for ties.
        self.row
            .cmp(&other.row)
            .then_with(|| other.run_idx.cmp(&self.run_idx))
    }
}
245
246// ---------------------------------------------------------------------------
247// SortedOutput — unifies the two output paths
248// ---------------------------------------------------------------------------
249
250enum SortedOutput<T: Ord + DeserializeOwned> {
251    Memory(std::vec::IntoIter<T>),
252    Merge(MergeIter<T>),
253}
254
255impl<T: Ord + DeserializeOwned> Iterator for SortedOutput<T> {
256    type Item = T;
257
258    fn next(&mut self) -> Option<T> {
259        match self {
260            SortedOutput::Memory(it) => it.next(),
261            SortedOutput::Merge(m) => m.next(),
262        }
263    }
264}
265
266// ---------------------------------------------------------------------------
267// MergeIter — k-way merge iterator
268// ---------------------------------------------------------------------------
269
270struct MergeIter<T: Ord + DeserializeOwned> {
271    heap: BinaryHeap<HeapEntry<T>>,
272    readers: Vec<RunReader<T>>,
273    exhausted: bool,
274}
275
276impl<T: Ord + DeserializeOwned> Iterator for MergeIter<T> {
277    type Item = T;
278
279    fn next(&mut self) -> Option<T> {
280        if self.exhausted {
281            return None;
282        }
283        let entry = self.heap.pop()?;
284        let row = entry.row.0;
285        let idx = entry.run_idx;
286
287        // Refill from the same run.
288        match self.readers[idx].next_row() {
289            Ok(Some(next_row)) => {
290                self.heap.push(HeapEntry {
291                    row: Reverse(next_row),
292                    run_idx: idx,
293                });
294            }
295            Ok(None) => { /* run exhausted */ }
296            Err(_) => {
297                self.exhausted = true;
298            }
299        }
300
301        Some(row)
302    }
303}
304
305// ---------------------------------------------------------------------------
306// SortableRow — pre-computed sort key + row payload (SPA-100).
307// ---------------------------------------------------------------------------
308
309use crate::types::Value;
310
/// `Ord`-safe wrapper for a single ORDER BY key value.
///
/// `Float64` stores the raw IEEE-754 bit pattern (`f64::to_bits`) so the
/// enum can derive `Eq`; numeric ordering of those bits is reconstructed in
/// the `Ord` impl below.  Values of different variants are ordered by a
/// fixed variant rank (see `discriminant`), with `Null` lowest.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum OrdValue {
    Null,
    Bool(bool),
    Int64(i64),
    // Raw `f64` bits from `to_bits`; compared via a total-order key in `Ord`.
    Float64(u64),
    String(String),
    // Any `Value` variant not explicitly mapped; all `Other`s compare equal.
    Other,
}
321
322impl OrdValue {
323    pub fn from_value(v: &Value) -> Self {
324        match v {
325            Value::Null => OrdValue::Null,
326            Value::Bool(b) => OrdValue::Bool(*b),
327            Value::Int64(i) => OrdValue::Int64(*i),
328            Value::Float64(f) => OrdValue::Float64(f.to_bits()),
329            Value::String(s) => OrdValue::String(s.clone()),
330            _ => OrdValue::Other,
331        }
332    }
333
334    fn discriminant(&self) -> u8 {
335        match self {
336            OrdValue::Null => 0,
337            OrdValue::Bool(_) => 1,
338            OrdValue::Int64(_) => 2,
339            OrdValue::Float64(_) => 3,
340            OrdValue::String(_) => 4,
341            OrdValue::Other => 5,
342        }
343    }
344}
345
346impl PartialOrd for OrdValue {
347    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
348        Some(self.cmp(other))
349    }
350}
351
352impl Ord for OrdValue {
353    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
354        match (self, other) {
355            (OrdValue::Null, OrdValue::Null) => std::cmp::Ordering::Equal,
356            (OrdValue::Bool(a), OrdValue::Bool(b)) => a.cmp(b),
357            (OrdValue::Int64(a), OrdValue::Int64(b)) => a.cmp(b),
358            (OrdValue::Float64(a), OrdValue::Float64(b)) => {
359                let ord_bits = |bits: u64| -> u64 {
360                    if bits >> 63 == 1 {
361                        !bits
362                    } else {
363                        bits | (1u64 << 63)
364                    }
365                };
366                ord_bits(*a).cmp(&ord_bits(*b))
367            }
368            (OrdValue::String(a), OrdValue::String(b)) => a.cmp(b),
369            _ => self.discriminant().cmp(&other.discriminant()),
370        }
371    }
372}
373
/// A single ORDER BY key entry that encodes direction in the variant.
///
/// `Desc` wraps the value in [`Reverse`] so the payload's natural `Ord`
/// yields descending order without a separate comparator.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum SortKeyVal {
    Asc(OrdValue),
    Desc(Reverse<OrdValue>),
}
380
381impl PartialOrd for SortKeyVal {
382    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
383        Some(self.cmp(other))
384    }
385}
386
387impl Ord for SortKeyVal {
388    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
389        match (self, other) {
390            (SortKeyVal::Asc(a), SortKeyVal::Asc(b)) => a.cmp(b),
391            (SortKeyVal::Desc(a), SortKeyVal::Desc(b)) => a.cmp(b),
392            _ => std::cmp::Ordering::Equal,
393        }
394    }
395}
396
/// Row wrapped with a pre-computed sort key for use with `SpillingSorter`.
///
/// `Ord` is defined by `key` only; `data` is the payload and ignored during
/// comparison so that the k-way merge produces a correctly-ordered result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SortableRow {
    /// One entry per ORDER BY column (direction encoded in `SortKeyVal`).
    pub key: Vec<SortKeyVal>,
    /// The projected row values; never participates in comparisons.
    pub data: Vec<Value>,
}
406
407impl PartialEq for SortableRow {
408    fn eq(&self, other: &Self) -> bool {
409        self.key == other.key
410    }
411}
412
413impl Eq for SortableRow {}
414
415impl PartialOrd for SortableRow {
416    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
417        Some(self.cmp(other))
418    }
419}
420
421impl Ord for SortableRow {
422    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
423        self.key.cmp(&other.key)
424    }
425}
426
427// ---------------------------------------------------------------------------
428// Tests
429// ---------------------------------------------------------------------------
430
431#[cfg(test)]
432mod tests {
433    use super::*;
434
435    /// Sort 1,000 rows that fit entirely in memory.
436    #[test]
437    fn sort_fits_in_memory() {
438        let mut sorter: SpillingSorter<i64> = SpillingSorter::new();
439        // Push in reverse order.
440        for i in (0i64..1_000).rev() {
441            sorter.push(i).unwrap();
442        }
443        let result: Vec<i64> = sorter.finish().unwrap().collect();
444        let expected: Vec<i64> = (0..1_000).collect();
445        assert_eq!(result, expected);
446    }
447
448    /// Sort more than the row threshold, triggering at least one spill.
449    #[test]
450    fn sort_spills_to_disk() {
451        // Use a tiny threshold so we definitely spill.
452        let mut sorter: SpillingSorter<i64> = SpillingSorter::with_thresholds(100, usize::MAX);
453
454        let n = 500i64;
455        for i in (0..n).rev() {
456            sorter.push(i).unwrap();
457        }
458        // Verify that spill files were actually created.
459        assert!(!sorter.runs.is_empty(), "expected at least one spill run");
460
461        let result: Vec<i64> = sorter.finish().unwrap().collect();
462        let expected: Vec<i64> = (0..n).collect();
463        assert_eq!(result, expected);
464    }
465
466    /// Empty input produces empty output.
467    #[test]
468    fn sort_empty() {
469        let sorter: SpillingSorter<i64> = SpillingSorter::new();
470        let result: Vec<i64> = sorter.finish().unwrap().collect();
471        assert!(result.is_empty());
472    }
473
474    /// Verify that spill temp files are cleaned up after finish() completes.
475    /// The simplest correctness check: a spilling sort returns the right output,
476    /// which would fail or corrupt data if the RunReader lost its file handle.
477    #[test]
478    fn sort_spill_no_temp_files_remain() {
479        let mut sorter: SpillingSorter<u64> = SpillingSorter::with_thresholds(10, usize::MAX);
480        for i in 0..50u64 {
481            sorter.push(50 - i).unwrap();
482        }
483        let result: Vec<u64> = sorter.finish().unwrap().collect();
484        assert_eq!(result, (1..=50u64).collect::<Vec<_>>());
485    }
486
487    /// Multi-column sort: tuples (key, value) sorted by key.
488    #[test]
489    fn sort_tuples() {
490        use serde::{Deserialize, Serialize};
491
492        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
493        struct Row {
494            key: i64,
495            val: String,
496        }
497
498        let mut sorter: SpillingSorter<Row> = SpillingSorter::with_thresholds(3, usize::MAX);
499
500        let rows = vec![
501            Row {
502                key: 3,
503                val: "c".into(),
504            },
505            Row {
506                key: 1,
507                val: "a".into(),
508            },
509            Row {
510                key: 2,
511                val: "b".into(),
512            },
513            Row {
514                key: 5,
515                val: "e".into(),
516            },
517            Row {
518                key: 4,
519                val: "d".into(),
520            },
521        ];
522        for r in rows {
523            sorter.push(r).unwrap();
524        }
525        let result: Vec<Row> = sorter.finish().unwrap().collect();
526        assert_eq!(result[0].key, 1);
527        assert_eq!(result[4].key, 5);
528    }
529}