Skip to main content

ruvector_dag/dag/
operator_node.rs

//! Operator node types and definitions for the query DAG.
2
3use serde::{Deserialize, Serialize};
4
/// Types of operators in a query DAG.
///
/// Each variant names one operator kind. Struct-like variants carry that
/// operator's parameters as plain strings and numbers; nothing in this module
/// parses or validates those values, it only stores them.
///
/// The type derives `Serialize`/`Deserialize`, so variant and field names are
/// part of the serialized representation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum OperatorType {
    // Scan operators
    /// Sequential scan; `table` is the table name.
    SeqScan {
        table: String,
    },
    /// Index scan; `index` names the index and `table` the table.
    IndexScan {
        index: String,
        table: String,
    },
    /// HNSW scan over the index named `index`.
    ///
    /// NOTE(review): `ef_search` is presumably the HNSW search-time candidate
    /// list size — confirm against the code that executes this operator.
    HnswScan {
        index: String,
        ef_search: u32,
    },
    /// IVF-Flat scan over the index named `index`.
    ///
    /// NOTE(review): `nprobe` is presumably the number of lists probed per
    /// query — confirm against the code that executes this operator.
    IvfFlatScan {
        index: String,
        nprobe: u32,
    },

    // Join operators
    /// Nested loop join; carries no parameters.
    NestedLoopJoin,
    /// Hash join; `hash_key` is the key stored for the join.
    HashJoin {
        hash_key: String,
    },
    /// Merge join; `merge_key` is the key stored for the join.
    MergeJoin {
        merge_key: String,
    },

    // Aggregation
    /// Aggregation; `functions` lists the aggregate functions as strings.
    Aggregate {
        functions: Vec<String>,
    },
    /// Group-by; `keys` lists the grouping keys as strings.
    GroupBy {
        keys: Vec<String>,
    },

    // Filter/Project
    /// Filter; `predicate` is kept as an opaque string.
    Filter {
        predicate: String,
    },
    /// Projection; `columns` lists the projected columns.
    Project {
        columns: Vec<String>,
    },

    // Sort/Limit
    /// Sort by `keys`.
    ///
    /// `OperatorNode::sort` fills `descending` with one `false` per key, so
    /// `descending[i]` presumably applies to `keys[i]` — confirm with the
    /// consumer of this variant.
    Sort {
        keys: Vec<String>,
        descending: Vec<bool>,
    },
    /// Limit; `count` is the stored row limit.
    Limit {
        count: usize,
    },

    // Vector operations
    /// Vector distance computation; `metric` is the metric name as a string.
    VectorDistance {
        metric: String,
    },
    /// Rerank step; `model` is the model name as a string.
    Rerank {
        model: String,
    },

    // Utility
    /// Materialize operator; carries no parameters.
    Materialize,
    /// Result operator; carries no parameters.
    Result,

    // Backward compatibility variants (deprecated, use specific variants above)
    /// Generic scan kept for backward compatibility.
    #[deprecated(note = "Use SeqScan instead")]
    Scan,
    /// Generic join kept for backward compatibility.
    #[deprecated(note = "Use HashJoin or NestedLoopJoin instead")]
    Join,
}
77
/// A node in the query DAG.
///
/// Pairs an [`OperatorType`] with planner estimates, optional measured
/// statistics and an optional embedding vector. All fields are public; the
/// `with_*` builder methods are a convenience for setting them fluently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperatorNode {
    /// Identifier supplied by the caller; this type does not enforce uniqueness.
    pub id: usize,
    /// The operator this node represents, including its parameters.
    pub op_type: OperatorType,
    /// Estimated row count; `0.0` until set (see `with_estimates`).
    pub estimated_rows: f64,
    /// Estimated cost; `0.0` until set (see `with_estimates`). Units are not
    /// defined in this file.
    pub estimated_cost: f64,
    /// Measured row count; `None` until set (see `with_actuals`).
    pub actual_rows: Option<f64>,
    /// Measured time, in milliseconds per the field name; `None` until set
    /// (see `with_actuals`).
    pub actual_time_ms: Option<f64>,
    /// Optional embedding vector; `None` until set (see `with_embedding`).
    /// Its dimensionality is not constrained here.
    pub embedding: Option<Vec<f32>>,
}
89
90impl OperatorNode {
91    /// Create a new operator node
92    pub fn new(id: usize, op_type: OperatorType) -> Self {
93        Self {
94            id,
95            op_type,
96            estimated_rows: 0.0,
97            estimated_cost: 0.0,
98            actual_rows: None,
99            actual_time_ms: None,
100            embedding: None,
101        }
102    }
103
104    /// Create a sequential scan node
105    pub fn seq_scan(id: usize, table: &str) -> Self {
106        Self::new(
107            id,
108            OperatorType::SeqScan {
109                table: table.to_string(),
110            },
111        )
112    }
113
114    /// Create an index scan node
115    pub fn index_scan(id: usize, index: &str, table: &str) -> Self {
116        Self::new(
117            id,
118            OperatorType::IndexScan {
119                index: index.to_string(),
120                table: table.to_string(),
121            },
122        )
123    }
124
125    /// Create an HNSW scan node
126    pub fn hnsw_scan(id: usize, index: &str, ef_search: u32) -> Self {
127        Self::new(
128            id,
129            OperatorType::HnswScan {
130                index: index.to_string(),
131                ef_search,
132            },
133        )
134    }
135
136    /// Create an IVF-Flat scan node
137    pub fn ivf_flat_scan(id: usize, index: &str, nprobe: u32) -> Self {
138        Self::new(
139            id,
140            OperatorType::IvfFlatScan {
141                index: index.to_string(),
142                nprobe,
143            },
144        )
145    }
146
147    /// Create a nested loop join node
148    pub fn nested_loop_join(id: usize) -> Self {
149        Self::new(id, OperatorType::NestedLoopJoin)
150    }
151
152    /// Create a hash join node
153    pub fn hash_join(id: usize, key: &str) -> Self {
154        Self::new(
155            id,
156            OperatorType::HashJoin {
157                hash_key: key.to_string(),
158            },
159        )
160    }
161
162    /// Create a merge join node
163    pub fn merge_join(id: usize, key: &str) -> Self {
164        Self::new(
165            id,
166            OperatorType::MergeJoin {
167                merge_key: key.to_string(),
168            },
169        )
170    }
171
172    /// Create a filter node
173    pub fn filter(id: usize, predicate: &str) -> Self {
174        Self::new(
175            id,
176            OperatorType::Filter {
177                predicate: predicate.to_string(),
178            },
179        )
180    }
181
182    /// Create a project node
183    pub fn project(id: usize, columns: Vec<String>) -> Self {
184        Self::new(id, OperatorType::Project { columns })
185    }
186
187    /// Create a sort node
188    pub fn sort(id: usize, keys: Vec<String>) -> Self {
189        let descending = vec![false; keys.len()];
190        Self::new(id, OperatorType::Sort { keys, descending })
191    }
192
193    /// Create a sort node with descending flags
194    pub fn sort_with_order(id: usize, keys: Vec<String>, descending: Vec<bool>) -> Self {
195        Self::new(id, OperatorType::Sort { keys, descending })
196    }
197
198    /// Create a limit node
199    pub fn limit(id: usize, count: usize) -> Self {
200        Self::new(id, OperatorType::Limit { count })
201    }
202
203    /// Create an aggregate node
204    pub fn aggregate(id: usize, functions: Vec<String>) -> Self {
205        Self::new(id, OperatorType::Aggregate { functions })
206    }
207
208    /// Create a group by node
209    pub fn group_by(id: usize, keys: Vec<String>) -> Self {
210        Self::new(id, OperatorType::GroupBy { keys })
211    }
212
213    /// Create a vector distance node
214    pub fn vector_distance(id: usize, metric: &str) -> Self {
215        Self::new(
216            id,
217            OperatorType::VectorDistance {
218                metric: metric.to_string(),
219            },
220        )
221    }
222
223    /// Create a rerank node
224    pub fn rerank(id: usize, model: &str) -> Self {
225        Self::new(
226            id,
227            OperatorType::Rerank {
228                model: model.to_string(),
229            },
230        )
231    }
232
233    /// Create a materialize node
234    pub fn materialize(id: usize) -> Self {
235        Self::new(id, OperatorType::Materialize)
236    }
237
238    /// Create a result node
239    pub fn result(id: usize) -> Self {
240        Self::new(id, OperatorType::Result)
241    }
242
243    /// Set estimated statistics
244    pub fn with_estimates(mut self, rows: f64, cost: f64) -> Self {
245        self.estimated_rows = rows;
246        self.estimated_cost = cost;
247        self
248    }
249
250    /// Set actual statistics
251    pub fn with_actuals(mut self, rows: f64, time_ms: f64) -> Self {
252        self.actual_rows = Some(rows);
253        self.actual_time_ms = Some(time_ms);
254        self
255    }
256
257    /// Set embedding vector
258    pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
259        self.embedding = Some(embedding);
260        self
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built node carries its id and operator and has no statistics yet.
    #[test]
    fn test_operator_node_creation() {
        let node = OperatorNode::seq_scan(1, "users");
        assert_eq!(node.id, 1);
        assert_eq!(
            node.op_type,
            OperatorType::SeqScan {
                table: "users".to_string()
            }
        );
        assert_eq!(node.estimated_rows, 0.0);
        assert_eq!(node.estimated_cost, 0.0);
        assert_eq!(node.actual_rows, None);
        assert_eq!(node.actual_time_ms, None);
        assert_eq!(node.embedding, None);
    }

    /// The `with_*` builders set exactly the fields they name.
    #[test]
    fn test_builder_pattern() {
        let node = OperatorNode::hash_join(2, "id")
            .with_estimates(1000.0, 50.0)
            .with_actuals(987.0, 45.2);

        assert_eq!(node.estimated_rows, 1000.0);
        assert_eq!(node.estimated_cost, 50.0);
        assert_eq!(node.actual_rows, Some(987.0));
        assert_eq!(node.actual_time_ms, Some(45.2));
    }

    /// `sort` produces one ascending (`false`) flag per key.
    #[test]
    fn test_sort_defaults_to_ascending() {
        let node = OperatorNode::sort(4, vec!["a".to_string(), "b".to_string()]);
        assert_eq!(
            node.op_type,
            OperatorType::Sort {
                keys: vec!["a".to_string(), "b".to_string()],
                descending: vec![false, false],
            }
        );
    }

    /// A JSON round trip preserves the operator and statistics, not just the id.
    #[test]
    fn test_serialization() {
        let node = OperatorNode::hnsw_scan(3, "embeddings_idx", 100).with_estimates(10.0, 2.5);
        let json = serde_json::to_string(&node).unwrap();
        let deserialized: OperatorNode = serde_json::from_str(&json).unwrap();
        assert_eq!(node.id, deserialized.id);
        assert_eq!(node.op_type, deserialized.op_type);
        assert_eq!(node.estimated_rows, deserialized.estimated_rows);
        assert_eq!(node.estimated_cost, deserialized.estimated_cost);
        assert_eq!(node.actual_rows, deserialized.actual_rows);
        assert_eq!(node.actual_time_ms, deserialized.actual_time_ms);
        assert_eq!(node.embedding, deserialized.embedding);
    }
}