sparrowdb_execution/
join_spill.rs1use std::collections::{HashMap, HashSet};
16use std::io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write};
17
18use sparrowdb_common::{Error, Result};
19use sparrowdb_storage::csr::CsrForward;
20use tempfile::NamedTempFile;
21
/// Default number of temporary partition files used by
/// `SpillingHashJoin::new` on the spilling path.
const NUM_PARTITIONS: usize = 16;

/// Default cap used by `SpillingHashJoin::new`: when the estimated number of
/// intermediate `(mid, fof)` pairs (the sum of the direct neighbours'
/// out-degrees) is at most this value the two-hop join runs in memory,
/// otherwise it spills to temporary files.
pub const SPILL_THRESHOLD: usize = 500_000;
27
28pub struct SpillingHashJoin<'a> {
38 csr: &'a CsrForward,
39 spill_threshold: usize,
40 num_partitions: usize,
41}
42
43impl<'a> SpillingHashJoin<'a> {
44 pub fn new(csr: &'a CsrForward) -> Self {
46 SpillingHashJoin {
47 csr,
48 spill_threshold: SPILL_THRESHOLD,
49 num_partitions: NUM_PARTITIONS,
50 }
51 }
52
53 pub fn with_thresholds(
55 csr: &'a CsrForward,
56 spill_threshold: usize,
57 num_partitions: usize,
58 ) -> Self {
59 SpillingHashJoin {
60 csr,
61 spill_threshold,
62 num_partitions: num_partitions.max(1), }
64 }
65
66 pub fn two_hop(&self, src_slot: u64) -> Result<Vec<u64>> {
71 let direct = self.csr.neighbors(src_slot);
72 if direct.is_empty() {
73 return Ok(vec![]);
74 }
75
76 let total_fof_estimate: usize = direct
79 .iter()
80 .map(|&mid| self.csr.neighbors(mid).len())
81 .sum();
82
83 if total_fof_estimate <= self.spill_threshold {
84 return self.two_hop_in_memory(direct);
85 }
86
87 self.two_hop_spilling(direct)
89 }
90
91 fn two_hop_in_memory(&self, direct: &[u64]) -> Result<Vec<u64>> {
94 let mut hash: HashMap<u64, Vec<u64>> = HashMap::new();
95 for &mid in direct {
96 let fof_list = self.csr.neighbors(mid);
97 if !fof_list.is_empty() {
98 hash.entry(mid).or_default().extend_from_slice(fof_list);
99 }
100 }
101
102 let mut fof_set: HashSet<u64> = HashSet::new();
103 for &mid in direct {
104 if let Some(fof_list) = hash.get(&mid) {
105 fof_set.extend(fof_list.iter().copied());
106 }
107 }
108
109 let mut result: Vec<u64> = fof_set.into_iter().collect();
110 result.sort_unstable();
111 Ok(result)
112 }
113
114 fn two_hop_spilling(&self, direct: &[u64]) -> Result<Vec<u64>> {
117 let np = self.num_partitions;
118
119 let mut part_files: Vec<NamedTempFile> = (0..np)
121 .map(|_| NamedTempFile::new().map_err(Error::Io))
122 .collect::<Result<_>>()?;
123
124 {
125 let mut writers: Vec<BufWriter<&mut std::fs::File>> = part_files
126 .iter_mut()
127 .map(|f| BufWriter::new(f.as_file_mut()))
128 .collect();
129
130 for &mid in direct {
131 let fof_list = self.csr.neighbors(mid);
132 if fof_list.is_empty() {
133 continue;
134 }
135 let p = (mid as usize) % np;
136 for &fof in fof_list {
137 write_u64_pair(&mut writers[p], mid, fof)?;
138 }
139 }
140
141 for w in &mut writers {
142 w.flush().map_err(Error::Io)?;
143 }
144 }
145
146 let mut fof_set: HashSet<u64> = HashSet::new();
148
149 for file in &mut part_files {
150 file.as_file_mut()
151 .seek(SeekFrom::Start(0))
152 .map_err(Error::Io)?;
153 let mut reader = BufReader::new(file.as_file_mut());
154
155 let mut hash: HashMap<u64, Vec<u64>> = HashMap::new();
156 while let Some((mid, fof)) = read_u64_pair(&mut reader)? {
157 hash.entry(mid).or_default().push(fof);
158 }
159
160 for fof_list in hash.values() {
161 fof_set.extend(fof_list.iter().copied());
162 }
163 }
164
165 let mut result: Vec<u64> = fof_set.into_iter().collect();
166 result.sort_unstable();
167 Ok(result)
168 }
169}
170
171fn write_u64_pair<W: Write>(w: &mut W, a: u64, b: u64) -> Result<()> {
176 w.write_all(&a.to_le_bytes()).map_err(Error::Io)?;
177 w.write_all(&b.to_le_bytes()).map_err(Error::Io)?;
178 Ok(())
179}
180
181fn read_u64_pair<R: Read>(r: &mut R) -> Result<Option<(u64, u64)>> {
182 let mut buf = [0u8; 8];
183 match r.read_exact(&mut buf) {
184 Ok(()) => {}
185 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
186 Err(e) => return Err(Error::Io(e)),
187 }
188 let a = u64::from_le_bytes(buf);
189 r.read_exact(&mut buf).map_err(Error::Io)?;
190 let b = u64::from_le_bytes(buf);
191 Ok(Some((a, b)))
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::join::AspJoin;

    /// Five-node graph: 0 -> {1, 2}, 1 -> {3}, 2 -> {3, 4}.
    /// Node 0 is "Alice" and node 1 is "Bob" in the assertion messages.
    fn social_graph() -> CsrForward {
        let edges = vec![(0u64, 1u64), (0, 2), (1, 3), (2, 3), (2, 4)];
        CsrForward::build(5u64, &edges)
    }

    /// The default (in-memory) configuration and a forced-spill
    /// configuration must both agree with the `AspJoin` baseline.
    #[test]
    fn join_spill_small_graph() {
        let csr = social_graph();
        let baseline = AspJoin::new(&csr);
        let in_memory = SpillingHashJoin::new(&csr);
        // Threshold 0 routes every query with a non-zero two-hop estimate
        // through the spill files. With 3 partitions, Alice's mids {1, 2}
        // leave partition 0 empty, which is also read back.
        let spilling = SpillingHashJoin::with_thresholds(&csr, 0, 3);

        for (src, who) in [(0u64, "Alice"), (1u64, "Bob")].iter().copied() {
            let expected = baseline.two_hop(src).unwrap();
            assert_eq!(
                in_memory.two_hop(src).unwrap(),
                expected,
                "{} fof mismatch (in-memory)",
                who
            );
            assert_eq!(
                spilling.two_hop(src).unwrap(),
                expected,
                "{} fof mismatch (spilled)",
                who
            );
        }
    }

    /// Ring plus a chord (i -> i+1, i -> i+7). Every node has out-degree 2,
    /// so each query's two-hop estimate is 4. That exceeds the threshold of
    /// 1 and forces the spill path; a threshold of 1 on a plain ring
    /// (estimate 1) would never spill. Node i+8 is reached through both
    /// i+1 and i+7, which land in different partitions (they differ by 6,
    /// i.e. 2 mod 4), so cross-partition de-duplication is exercised too.
    #[test]
    fn join_spill_large_graph() {
        const N: u64 = 10_000;

        let mut edges: Vec<(u64, u64)> = (0..N)
            .flat_map(|i| vec![(i, (i + 1) % N), (i, (i + 7) % N)])
            .collect();
        // Keep the edge list sorted in case the CSR builder relies on it.
        edges.sort_unstable();
        let csr = CsrForward::build(N, &edges);

        let baseline = AspJoin::new(&csr);
        let spilling = SpillingHashJoin::with_thresholds(&csr, 1, 4);

        assert_eq!(spilling.two_hop(0).unwrap(), vec![2, 8, 14]);

        // Each spilled query creates `num_partitions` temp files, so check a
        // spread-out sample of sources rather than all N.
        for src in (0..N).step_by(97) {
            let expected = baseline.two_hop(src).unwrap();
            let got = spilling.two_hop(src).unwrap();
            assert_eq!(got, expected, "ring+chord fof mismatch for src={}", src);
        }
    }

    /// A source with no out-edges has an empty two-hop neighbourhood.
    #[test]
    fn join_spill_no_edges() {
        let csr = CsrForward::build(3u64, &[(1u64, 2u64)]);
        let spilling = SpillingHashJoin::new(&csr);
        let got = spilling.two_hop(0).unwrap();
        assert!(got.is_empty());
    }

    /// A partition count of 0 is clamped to 1 rather than dividing by zero.
    #[test]
    fn join_spill_zero_partitions_does_not_panic() {
        let csr = CsrForward::build(3u64, &[(0u64, 1u64), (1u64, 2u64)]);
        let join = SpillingHashJoin::with_thresholds(&csr, 0, 0);
        let result = join.two_hop(0).unwrap();
        assert_eq!(result, vec![2]);
    }

    /// Spill records survive the on-disk encoding unchanged, and a record
    /// cut short inside its second word is reported as an error.
    #[test]
    fn join_spill_record_roundtrip() {
        let mut bytes: Vec<u8> = Vec::new();
        write_u64_pair(&mut bytes, 7, u64::MAX).unwrap();
        write_u64_pair(&mut bytes, 0, 42).unwrap();
        assert_eq!(bytes.len(), 32);

        let mut reader: &[u8] = &bytes;
        assert_eq!(read_u64_pair(&mut reader).unwrap(), Some((7, u64::MAX)));
        assert_eq!(read_u64_pair(&mut reader).unwrap(), Some((0, 42)));
        assert_eq!(read_u64_pair(&mut reader).unwrap(), None);

        let mut truncated: &[u8] = &bytes[..12];
        assert!(read_u64_pair(&mut truncated).is_err());
    }
}