1use std::cmp::Reverse;
15use std::collections::BinaryHeap;
16use std::io::{BufReader, BufWriter, Read, Write};
17
18use serde::{de::DeserializeOwned, Deserialize, Serialize};
19use sparrowdb_common::{Error, Result};
20use tempfile::NamedTempFile;
21
/// Row count at which a sorter built with [`SpillingSorter::new`] spills its
/// in-memory buffer to a temporary run file.
pub const DEFAULT_ROW_THRESHOLD: usize = 100_000;

/// Estimated buffered byte size (64 MiB) at which a sorter built with
/// [`SpillingSorter::new`] spills its in-memory buffer to a temporary run file.
pub const DEFAULT_BYTE_THRESHOLD: usize = 64 << 20;
27
28pub struct SpillingSorter<T> {
35 buffer: Vec<T>,
37 runs: Vec<NamedTempFile>,
39 row_threshold: usize,
41 byte_threshold: usize,
43 byte_estimate: usize,
45 bytes_per_row: usize,
47}
48
49impl<T> SpillingSorter<T>
50where
51 T: Serialize + DeserializeOwned + Ord + Clone,
52{
53 pub fn new() -> Self {
55 SpillingSorter::with_thresholds(DEFAULT_ROW_THRESHOLD, DEFAULT_BYTE_THRESHOLD)
56 }
57
58 pub fn with_thresholds(row_threshold: usize, byte_threshold: usize) -> Self {
61 SpillingSorter {
62 buffer: Vec::new(),
63 runs: Vec::new(),
64 row_threshold,
65 byte_threshold,
66 byte_estimate: 0,
67 bytes_per_row: 64, }
69 }
70
71 pub fn push(&mut self, row: T) -> Result<()> {
74 self.byte_estimate += self.bytes_per_row;
75 self.buffer.push(row);
76
77 if self.buffer.len() >= self.row_threshold || self.byte_estimate >= self.byte_threshold {
78 self.spill()?;
79 }
80 Ok(())
81 }
82
83 pub fn finish(mut self) -> Result<impl Iterator<Item = T>> {
86 if self.runs.is_empty() {
87 self.buffer.sort();
89 return Ok(SortedOutput::Memory(self.buffer.into_iter()));
90 }
91
92 if !self.buffer.is_empty() {
94 self.spill()?;
95 }
96
97 let mut readers: Vec<RunReader<T>> = self
99 .runs
100 .into_iter()
101 .map(RunReader::new)
102 .collect::<Result<Vec<_>>>()?;
103
104 let mut heap: BinaryHeap<HeapEntry<T>> = BinaryHeap::new();
106 for (idx, reader) in readers.iter_mut().enumerate() {
107 if let Some(row) = reader.next_row()? {
108 heap.push(HeapEntry {
109 row: Reverse(row),
110 run_idx: idx,
111 });
112 }
113 }
114
115 Ok(SortedOutput::Merge(MergeIter {
116 heap,
117 readers,
118 exhausted: false,
119 }))
120 }
121
122 fn spill(&mut self) -> Result<()> {
126 self.buffer.sort();
127
128 if let Some(first) = self.buffer.first() {
131 if let Ok(encoded) = bincode::serialize(first) {
132 self.bytes_per_row = encoded.len() + 8;
134 }
135 }
136
137 let mut tmp = NamedTempFile::new().map_err(Error::Io)?;
138 {
139 let mut writer = BufWriter::new(tmp.as_file_mut());
140 for row in &self.buffer {
141 write_row(&mut writer, row)?;
142 }
143 writer.flush().map_err(Error::Io)?;
144 }
145
146 self.runs.push(tmp);
147 self.buffer.clear();
148 self.byte_estimate = 0;
149 Ok(())
150 }
151}
152
153impl<T> Default for SpillingSorter<T>
154where
155 T: Serialize + DeserializeOwned + Ord + Clone,
156{
157 fn default() -> Self {
158 Self::new()
159 }
160}
161
162fn write_row<W: Write, T: Serialize>(writer: &mut W, row: &T) -> Result<()> {
168 let encoded = bincode::serialize(row)
169 .map_err(|e| Error::InvalidArgument(format!("bincode encode: {e}")))?;
170 let len = encoded.len() as u64;
171 writer.write_all(&len.to_le_bytes()).map_err(Error::Io)?;
172 writer.write_all(&encoded).map_err(Error::Io)?;
173 Ok(())
174}
175
176fn read_row<R: Read, T: DeserializeOwned>(reader: &mut R) -> Result<Option<T>> {
178 let mut len_buf = [0u8; 8];
179 match reader.read_exact(&mut len_buf) {
180 Ok(()) => {}
181 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
182 Err(e) => return Err(Error::Io(e)),
183 }
184 let len = u64::from_le_bytes(len_buf) as usize;
185 let mut data = vec![0u8; len];
186 reader.read_exact(&mut data).map_err(Error::Io)?;
187 let row: T = bincode::deserialize(&data)
188 .map_err(|e| Error::Corruption(format!("bincode decode: {e}")))?;
189 Ok(Some(row))
190}
191
192struct RunReader<T> {
197 _tmpfile: NamedTempFile, reader: BufReader<std::fs::File>,
199 _marker: std::marker::PhantomData<T>,
200}
201
202impl<T: DeserializeOwned> RunReader<T> {
203 fn new(tmp: NamedTempFile) -> Result<Self> {
204 let read_handle = tmp.reopen().map_err(Error::Io)?;
207 Ok(RunReader {
208 _tmpfile: tmp,
209 reader: BufReader::new(read_handle),
210 _marker: std::marker::PhantomData,
211 })
212 }
213
214 fn next_row(&mut self) -> Result<Option<T>> {
215 read_row(&mut self.reader)
216 }
217}
218
/// One pending row in the k-way merge heap, tagged with the run it came from.
///
/// `BinaryHeap` is a max-heap, so the row is wrapped in `Reverse` to make the
/// smallest row surface first. Rows that compare equal are ordered by run
/// index, lowest first. Runs are created in push order and each run is sorted
/// with the stable `Vec::sort`, so this tie-break makes the merged output
/// stable, matching what the in-memory path produces.
struct HeapEntry<T: Ord> {
    row: Reverse<T>,
    run_idx: usize,
}

impl<T: Ord> PartialEq for HeapEntry<T> {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality and ordering can never disagree.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl<T: Ord> Eq for HeapEntry<T> {}

impl<T: Ord> PartialOrd for HeapEntry<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<T: Ord> Ord for HeapEntry<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.row
            .cmp(&other.row)
            // Operands swapped on purpose: in a max-heap the entry with the
            // LOWER run index must rank higher to be popped first.
            .then_with(|| other.run_idx.cmp(&self.run_idx))
    }
}
245
246enum SortedOutput<T: Ord + DeserializeOwned> {
251 Memory(std::vec::IntoIter<T>),
252 Merge(MergeIter<T>),
253}
254
255impl<T: Ord + DeserializeOwned> Iterator for SortedOutput<T> {
256 type Item = T;
257
258 fn next(&mut self) -> Option<T> {
259 match self {
260 SortedOutput::Memory(it) => it.next(),
261 SortedOutput::Merge(m) => m.next(),
262 }
263 }
264}
265
266struct MergeIter<T: Ord + DeserializeOwned> {
271 heap: BinaryHeap<HeapEntry<T>>,
272 readers: Vec<RunReader<T>>,
273 exhausted: bool,
274}
275
276impl<T: Ord + DeserializeOwned> Iterator for MergeIter<T> {
277 type Item = T;
278
279 fn next(&mut self) -> Option<T> {
280 if self.exhausted {
281 return None;
282 }
283 let entry = self.heap.pop()?;
284 let row = entry.row.0;
285 let idx = entry.run_idx;
286
287 match self.readers[idx].next_row() {
289 Ok(Some(next_row)) => {
290 self.heap.push(HeapEntry {
291 row: Reverse(next_row),
292 run_idx: idx,
293 });
294 }
295 Ok(None) => { }
296 Err(_) => {
297 self.exhausted = true;
298 }
299 }
300
301 Some(row)
302 }
303}
304
305use crate::types::Value;
310
311#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
313pub enum OrdValue {
314 Null,
315 Bool(bool),
316 Int64(i64),
317 Float64(u64),
318 String(String),
319 Other,
320}
321
322impl OrdValue {
323 pub fn from_value(v: &Value) -> Self {
324 match v {
325 Value::Null => OrdValue::Null,
326 Value::Bool(b) => OrdValue::Bool(*b),
327 Value::Int64(i) => OrdValue::Int64(*i),
328 Value::Float64(f) => OrdValue::Float64(f.to_bits()),
329 Value::String(s) => OrdValue::String(s.clone()),
330 _ => OrdValue::Other,
331 }
332 }
333
334 fn discriminant(&self) -> u8 {
335 match self {
336 OrdValue::Null => 0,
337 OrdValue::Bool(_) => 1,
338 OrdValue::Int64(_) => 2,
339 OrdValue::Float64(_) => 3,
340 OrdValue::String(_) => 4,
341 OrdValue::Other => 5,
342 }
343 }
344}
345
346impl PartialOrd for OrdValue {
347 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
348 Some(self.cmp(other))
349 }
350}
351
352impl Ord for OrdValue {
353 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
354 match (self, other) {
355 (OrdValue::Null, OrdValue::Null) => std::cmp::Ordering::Equal,
356 (OrdValue::Bool(a), OrdValue::Bool(b)) => a.cmp(b),
357 (OrdValue::Int64(a), OrdValue::Int64(b)) => a.cmp(b),
358 (OrdValue::Float64(a), OrdValue::Float64(b)) => {
359 let ord_bits = |bits: u64| -> u64 {
360 if bits >> 63 == 1 {
361 !bits
362 } else {
363 bits | (1u64 << 63)
364 }
365 };
366 ord_bits(*a).cmp(&ord_bits(*b))
367 }
368 (OrdValue::String(a), OrdValue::String(b)) => a.cmp(b),
369 _ => self.discriminant().cmp(&other.discriminant()),
370 }
371 }
372}
373
374#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
376pub enum SortKeyVal {
377 Asc(OrdValue),
378 Desc(Reverse<OrdValue>),
379}
380
381impl PartialOrd for SortKeyVal {
382 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
383 Some(self.cmp(other))
384 }
385}
386
387impl Ord for SortKeyVal {
388 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
389 match (self, other) {
390 (SortKeyVal::Asc(a), SortKeyVal::Asc(b)) => a.cmp(b),
391 (SortKeyVal::Desc(a), SortKeyVal::Desc(b)) => a.cmp(b),
392 _ => std::cmp::Ordering::Equal,
393 }
394 }
395}
396
397#[derive(Debug, Clone, Serialize, Deserialize)]
402pub struct SortableRow {
403 pub key: Vec<SortKeyVal>,
404 pub data: Vec<Value>,
405}
406
407impl PartialEq for SortableRow {
408 fn eq(&self, other: &Self) -> bool {
409 self.key == other.key
410 }
411}
412
413impl Eq for SortableRow {}
414
415impl PartialOrd for SortableRow {
416 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
417 Some(self.cmp(other))
418 }
419}
420
421impl Ord for SortableRow {
422 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
423 self.key.cmp(&other.key)
424 }
425}
426
#[cfg(test)]
mod tests {
    use super::*;

    /// Input below the default thresholds is sorted without spilling.
    #[test]
    fn sort_fits_in_memory() {
        let mut sorter: SpillingSorter<i64> = SpillingSorter::new();
        for i in (0i64..1_000).rev() {
            sorter.push(i).unwrap();
        }
        let result: Vec<i64> = sorter.finish().unwrap().collect();
        let expected: Vec<i64> = (0..1_000).collect();
        assert_eq!(result, expected);
    }

    /// A low row threshold forces several runs; the merge must still produce
    /// fully sorted output.
    #[test]
    fn sort_spills_to_disk() {
        let mut sorter: SpillingSorter<i64> = SpillingSorter::with_thresholds(100, usize::MAX);

        let n = 500i64;
        for i in (0..n).rev() {
            sorter.push(i).unwrap();
        }
        assert!(!sorter.runs.is_empty(), "expected at least one spill run");

        let result: Vec<i64> = sorter.finish().unwrap().collect();
        let expected: Vec<i64> = (0..n).collect();
        assert_eq!(result, expected);
    }

    /// Finishing a sorter that never received a row yields nothing.
    #[test]
    fn sort_empty() {
        let sorter: SpillingSorter<i64> = SpillingSorter::new();
        let result: Vec<i64> = sorter.finish().unwrap().collect();
        assert!(result.is_empty());
    }

    /// After the merged output has been consumed and dropped, none of the
    /// run files may be left on disk.
    #[test]
    fn sort_spill_no_temp_files_remain() {
        let mut sorter: SpillingSorter<u64> = SpillingSorter::with_thresholds(10, usize::MAX);
        for i in 0..50u64 {
            sorter.push(50 - i).unwrap();
        }

        // Record the run paths before `finish` takes ownership of the files.
        let run_paths: Vec<std::path::PathBuf> = sorter
            .runs
            .iter()
            .map(|run| run.path().to_path_buf())
            .collect();
        assert!(!run_paths.is_empty(), "expected at least one spill run");

        // The iterator is a temporary, so it is dropped at the end of this
        // statement together with the run files it owns.
        let result: Vec<u64> = sorter.finish().unwrap().collect();
        assert_eq!(result, (1..=50u64).collect::<Vec<_>>());

        for path in &run_paths {
            assert!(
                !path.exists(),
                "temp file {} was not removed",
                path.display()
            );
        }
    }

    /// Struct rows spread over more than one run come back in full order.
    #[test]
    fn sort_tuples() {
        use serde::{Deserialize, Serialize};

        #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
        struct Row {
            key: i64,
            val: String,
        }

        let row = |key: i64, val: &str| Row {
            key,
            val: val.into(),
        };

        let mut sorter: SpillingSorter<Row> = SpillingSorter::with_thresholds(3, usize::MAX);

        let rows = vec![
            row(3, "c"),
            row(1, "a"),
            row(2, "b"),
            row(5, "e"),
            row(4, "d"),
        ];
        for r in rows {
            sorter.push(r).unwrap();
        }

        let result: Vec<Row> = sorter.finish().unwrap().collect();
        let expected = vec![
            row(1, "a"),
            row(2, "b"),
            row(3, "c"),
            row(4, "d"),
            row(5, "e"),
        ];
        assert_eq!(result, expected);
    }
}