Skip to main content

sparrowdb_execution/
operators.rs

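//! Execution operators: LabelScan, Filter, Project, Expand, and UnwindOperator.
//!
//! Each operator implements the `Operator` trait, whose `next_chunk()` method
//! follows a pull-based iterator pattern and returns
//! `Result<Option<FactorizedChunk>>` (`Ok(None)` once the operator is exhausted).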
5
6use std::collections::HashMap;
7
8use sparrowdb_common::{NodeId, Result};
9use sparrowdb_storage::node_store::NodeStore;
10
11use crate::types::{FactorizedChunk, TypedVector, Value, VectorGroup};
12
13// ── Operator trait ────────────────────────────────────────────────────────────
14
15/// Operator trait: produces `FactorizedChunk`s lazily.
16pub trait Operator {
17    fn next_chunk(&mut self) -> Result<Option<FactorizedChunk>>;
18
19    /// Drain all chunks and materialize as a Vec<FactorizedChunk>.
20    fn collect_all(&mut self) -> Result<Vec<FactorizedChunk>> {
21        let mut result = Vec::new();
22        while let Some(chunk) = self.next_chunk()? {
23            result.push(chunk);
24        }
25        Ok(result)
26    }
27}
28
29// ── LabelScan ────────────────────────────────────────────────────────────────
30
31/// Scans all slots for a given label_id, reading specified columns.
32///
33/// Each call to `next_chunk()` returns one chunk containing one VectorGroup
34/// with all rows for the label (for simplicity in Phase 4; chunked in Phase 5+).
35pub struct LabelScan<'a> {
36    store: &'a NodeStore,
37    label_id: u32,
38    col_ids: Vec<u32>,
39    done: bool,
40}
41
42impl<'a> LabelScan<'a> {
43    pub fn new(store: &'a NodeStore, label_id: u32, col_ids: &[u32]) -> Self {
44        LabelScan {
45            store,
46            label_id,
47            col_ids: col_ids.to_vec(),
48            done: false,
49        }
50    }
51}
52
53impl<'a> Operator for LabelScan<'a> {
54    fn next_chunk(&mut self) -> Result<Option<FactorizedChunk>> {
55        if self.done {
56            return Ok(None);
57        }
58        self.done = true;
59
60        // Read the HWM to know how many slots exist.
61        let hwm = self.store.hwm_for_label(self.label_id)?;
62        if hwm == 0 {
63            return Ok(Some(FactorizedChunk::new()));
64        }
65
66        let n = hwm as usize;
67        let mut col_vecs: HashMap<String, Vec<i64>> = HashMap::new();
68        let mut node_ids: Vec<NodeId> = Vec::with_capacity(n);
69
70        // Read each column for all slots.
71        for &col_id in &self.col_ids {
72            let mut vals = Vec::with_capacity(n);
73            for slot in 0..hwm as u32 {
74                let node_id = NodeId(((self.label_id as u64) << 32) | (slot as u64));
75                let raw = self.store.get_node_raw(node_id, &[col_id])?;
76                vals.push(if raw.is_empty() { 0 } else { raw[0].1 as i64 });
77            }
78            col_vecs.insert(format!("col_{col_id}"), vals);
79        }
80
81        // Build node_id vector.
82        for slot in 0..hwm as u32 {
83            let node_id = NodeId(((self.label_id as u64) << 32) | (slot as u64));
84            node_ids.push(node_id);
85        }
86
87        let mut group = VectorGroup::new(1);
88        group.add_column("__node_id__".into(), TypedVector::NodeRef(node_ids));
89        for (name, vals) in col_vecs {
90            group.add_column(name, TypedVector::Int64(vals));
91        }
92
93        let mut chunk = FactorizedChunk::new();
94        chunk.push_group(group);
95        Ok(Some(chunk))
96    }
97}
98
99// ── Filter ────────────────────────────────────────────────────────────────────
100
101/// Filters rows in each chunk where `column_name == predicate_value`.
102pub struct Filter<'a, O: Operator + 'a> {
103    inner: &'a mut O,
104    column: String,
105    predicate: FilterPredicate,
106}
107
108/// Filter predicate types.
109pub enum FilterPredicate {
110    /// column == value
111    Eq(Value),
112    /// column CONTAINS string
113    Contains(String),
114    /// column > value (int)
115    Gt(i64),
116    /// column >= value
117    Ge(i64),
118    /// column < value
119    Lt(i64),
120}
121
122impl<'a, O: Operator> Filter<'a, O> {
123    /// Construct a filter: `column_name == value`.
124    pub fn new(inner: &'a mut O, column_name: &str, value: Value) -> Self {
125        Filter {
126            inner,
127            column: column_name.to_string(),
128            predicate: FilterPredicate::Eq(value),
129        }
130    }
131
132    /// Construct a CONTAINS filter.
133    pub fn contains(inner: &'a mut O, column_name: &str, substr: &str) -> Self {
134        Filter {
135            inner,
136            column: column_name.to_string(),
137            predicate: FilterPredicate::Contains(substr.to_string()),
138        }
139    }
140
141    /// Construct a greater-than filter (int64).
142    pub fn gt(inner: &'a mut O, column_name: &str, val: i64) -> Self {
143        Filter {
144            inner,
145            column: column_name.to_string(),
146            predicate: FilterPredicate::Gt(val),
147        }
148    }
149
150    fn matches(&self, v: &Value) -> bool {
151        match &self.predicate {
152            FilterPredicate::Eq(expected) => v == expected,
153            FilterPredicate::Contains(substr) => match v {
154                Value::String(s) => s.contains(substr.as_str()),
155                _ => false,
156            },
157            FilterPredicate::Gt(thresh) => match v {
158                Value::Int64(n) => *n > *thresh,
159                _ => false,
160            },
161            FilterPredicate::Ge(thresh) => match v {
162                Value::Int64(n) => *n >= *thresh,
163                _ => false,
164            },
165            FilterPredicate::Lt(thresh) => match v {
166                Value::Int64(n) => *n < *thresh,
167                _ => false,
168            },
169        }
170    }
171
172    fn filter_group(&self, group: VectorGroup) -> Option<VectorGroup> {
173        let col = group.columns.get(&self.column)?;
174        let n = col.len();
175
176        // Compute keep mask.
177        let keep: Vec<bool> = (0..n).map(|i| self.matches(&col.get(i))).collect();
178
179        if keep.iter().all(|&k| !k) {
180            return None;
181        }
182
183        let mut new_group = VectorGroup::new(group.multiplicity);
184        for (col_name, col_vec) in &group.columns {
185            let filtered = filter_typed_vector(col_vec, &keep);
186            new_group.add_column(col_name.clone(), filtered);
187        }
188        if new_group.is_empty() {
189            None
190        } else {
191            Some(new_group)
192        }
193    }
194}
195
196impl<'a, O: Operator> Operator for Filter<'a, O> {
197    fn next_chunk(&mut self) -> Result<Option<FactorizedChunk>> {
198        loop {
199            match self.inner.next_chunk()? {
200                None => return Ok(None),
201                Some(chunk) => {
202                    let mut out = FactorizedChunk::new();
203                    for group in chunk.groups {
204                        if let Some(filtered) = self.filter_group(group) {
205                            out.push_group(filtered);
206                        }
207                    }
208                    if !out.is_empty() {
209                        return Ok(Some(out));
210                    }
211                    // If all groups were filtered out, ask for the next chunk.
212                }
213            }
214        }
215    }
216}
217
218fn filter_typed_vector(vec: &TypedVector, keep: &[bool]) -> TypedVector {
219    match vec {
220        TypedVector::Int64(v) => TypedVector::Int64(
221            v.iter()
222                .zip(keep)
223                .filter_map(|(x, &k)| if k { Some(*x) } else { None })
224                .collect(),
225        ),
226        TypedVector::Float64(v) => TypedVector::Float64(
227            v.iter()
228                .zip(keep)
229                .filter_map(|(x, &k)| if k { Some(*x) } else { None })
230                .collect(),
231        ),
232        TypedVector::Bool(v) => TypedVector::Bool(
233            v.iter()
234                .zip(keep)
235                .filter_map(|(x, &k)| if k { Some(*x) } else { None })
236                .collect(),
237        ),
238        TypedVector::String(v) => TypedVector::String(
239            v.iter()
240                .zip(keep)
241                .filter_map(|(x, &k)| if k { Some(x.clone()) } else { None })
242                .collect(),
243        ),
244        TypedVector::NodeRef(v) => TypedVector::NodeRef(
245            v.iter()
246                .zip(keep)
247                .filter_map(|(x, &k)| if k { Some(*x) } else { None })
248                .collect(),
249        ),
250        TypedVector::EdgeRef(v) => TypedVector::EdgeRef(
251            v.iter()
252                .zip(keep)
253                .filter_map(|(x, &k)| if k { Some(*x) } else { None })
254                .collect(),
255        ),
256    }
257}
258
259// ── Project ───────────────────────────────────────────────────────────────────
260
261/// Projects (selects) specific named columns from each chunk.
262pub struct Project<'a, O: Operator + 'a> {
263    inner: &'a mut O,
264    columns: Vec<String>,
265}
266
267impl<'a, O: Operator> Project<'a, O> {
268    pub fn new(inner: &'a mut O, columns: Vec<String>) -> Self {
269        Project { inner, columns }
270    }
271}
272
273impl<'a, O: Operator> Operator for Project<'a, O> {
274    fn next_chunk(&mut self) -> Result<Option<FactorizedChunk>> {
275        match self.inner.next_chunk()? {
276            None => Ok(None),
277            Some(chunk) => {
278                let mut out = FactorizedChunk::new();
279                for group in chunk.groups {
280                    let mut new_group = VectorGroup::new(group.multiplicity);
281                    for col_name in &self.columns {
282                        if let Some(col) = group.columns.get(col_name) {
283                            new_group.add_column(col_name.clone(), col.clone());
284                        }
285                    }
286                    out.push_group(new_group);
287                }
288                Ok(Some(out))
289            }
290        }
291    }
292}
293
294// ── Expand ────────────────────────────────────────────────────────────────────
295
296/// Expands a NodeRef column by looking up neighbors in a CSR.
297///
298/// For each node in `src_col`, produces neighbor node IDs in `dst_col`.
299/// Preserves group multiplicity.
300pub struct Expand<'a, O: Operator + 'a> {
301    inner: &'a mut O,
302    src_col: String,
303    dst_col: String,
304    csr: &'a sparrowdb_storage::csr::CsrForward,
305    // Buffered output chunks from expansion.
306    buffer: Vec<FactorizedChunk>,
307    done: bool,
308}
309
310impl<'a, O: Operator> Expand<'a, O> {
311    pub fn new(
312        inner: &'a mut O,
313        src_col: &str,
314        dst_col: &str,
315        csr: &'a sparrowdb_storage::csr::CsrForward,
316    ) -> Self {
317        Expand {
318            inner,
319            src_col: src_col.to_string(),
320            dst_col: dst_col.to_string(),
321            csr,
322            buffer: Vec::new(),
323            done: false,
324        }
325    }
326}
327
328impl<'a, O: Operator> Operator for Expand<'a, O> {
329    fn next_chunk(&mut self) -> Result<Option<FactorizedChunk>> {
330        if !self.buffer.is_empty() {
331            return Ok(Some(self.buffer.remove(0)));
332        }
333        if self.done {
334            return Ok(None);
335        }
336
337        match self.inner.next_chunk()? {
338            None => {
339                self.done = true;
340                Ok(None)
341            }
342            Some(chunk) => {
343                let mut out = FactorizedChunk::new();
344                for group in chunk.groups {
345                    let node_col = match group.columns.get(&self.src_col) {
346                        Some(TypedVector::NodeRef(v)) => v.clone(),
347                        _ => continue,
348                    };
349
350                    // For each source node, expand to neighbors.
351                    for src_node in &node_col {
352                        let slot = src_node.0 & 0xFFFF_FFFF;
353                        let neighbors = self.csr.neighbors(slot);
354                        if neighbors.is_empty() {
355                            continue;
356                        }
357
358                        // Build a new group for each src with all its neighbors.
359                        let label_id = src_node.0 >> 32;
360                        let dst_nodes: Vec<NodeId> = neighbors
361                            .iter()
362                            .map(|&nb_slot| NodeId((label_id << 32) | nb_slot))
363                            .collect();
364
365                        let n = dst_nodes.len();
366                        let mut new_group = VectorGroup::new(group.multiplicity);
367                        // Repeat the source node N times to match the destination vector length,
368                        // preserving the VectorGroup invariant that all columns have equal length.
369                        new_group.add_column(
370                            self.src_col.clone(),
371                            TypedVector::NodeRef(vec![*src_node; n]),
372                        );
373                        new_group.add_column(self.dst_col.clone(), TypedVector::NodeRef(dst_nodes));
374                        out.push_group(new_group);
375                    }
376                }
377                if out.is_empty() {
378                    // Recurse to get next chunk with data.
379                    self.next_chunk()
380                } else {
381                    Ok(Some(out))
382                }
383            }
384        }
385    }
386}
387
388// ── UnwindOperator ────────────────────────────────────────────────────────────
389
390/// Iterates a list of scalar `Value`s, emitting one row per element.
391///
392/// Each row has a single column named after `alias`.
393/// Empty lists produce zero rows.
394pub struct UnwindOperator {
395    /// Pre-evaluated list of values to iterate.
396    values: Vec<crate::types::Value>,
397    /// Column name bound to each element.
398    alias: String,
399    /// Index of the next value to emit.
400    idx: usize,
401    done: bool,
402}
403
404impl UnwindOperator {
405    /// Create an UNWIND operator that emits each element of `values` in turn.
406    pub fn new(alias: String, values: Vec<crate::types::Value>) -> Self {
407        let done = values.is_empty();
408        UnwindOperator {
409            values,
410            alias,
411            idx: 0,
412            done,
413        }
414    }
415}
416
417impl Operator for UnwindOperator {
418    fn next_chunk(&mut self) -> Result<Option<FactorizedChunk>> {
419        if self.done {
420            return Ok(None);
421        }
422
423        // Emit all elements in a single chunk as typed vectors.
424        // We detect the type from the first element and coerce the rest;
425        // mixed-type lists produce Int64 / String / Float64 chunks respectively.
426        let remaining = &self.values[self.idx..];
427        if remaining.is_empty() {
428            self.done = true;
429            return Ok(None);
430        }
431
432        // Build a TypedVector matching the dominant type.
433        let typed = build_typed_vector(remaining);
434        self.idx = self.values.len();
435        self.done = true;
436
437        let mut group = VectorGroup::new(1);
438        group.add_column(self.alias.clone(), typed);
439        let mut chunk = FactorizedChunk::new();
440        chunk.push_group(group);
441        Ok(Some(chunk))
442    }
443}
444
445/// Convert a slice of `Value`s into a `TypedVector`.
446///
447/// If all values are the same primitive type, uses that type's vector;
448/// otherwise falls back to `String` (via `Display`).
449fn build_typed_vector(values: &[crate::types::Value]) -> TypedVector {
450    use crate::types::Value;
451
452    // Check if all values are Int64.
453    if values.iter().all(|v| matches!(v, Value::Int64(_))) {
454        return TypedVector::Int64(
455            values
456                .iter()
457                .map(|v| match v {
458                    Value::Int64(n) => *n,
459                    _ => unreachable!(),
460                })
461                .collect(),
462        );
463    }
464
465    // Check if all values are Float64.
466    if values.iter().all(|v| matches!(v, Value::Float64(_))) {
467        return TypedVector::Float64(
468            values
469                .iter()
470                .map(|v| match v {
471                    Value::Float64(f) => *f,
472                    _ => unreachable!(),
473                })
474                .collect(),
475        );
476    }
477
478    // Check if all values are Bool.
479    if values.iter().all(|v| matches!(v, Value::Bool(_))) {
480        return TypedVector::Bool(
481            values
482                .iter()
483                .map(|v| match v {
484                    Value::Bool(b) => *b,
485                    _ => unreachable!(),
486                })
487                .collect(),
488        );
489    }
490
491    // Fall back to String.
492    TypedVector::String(values.iter().map(|v| v.to_string()).collect())
493}