ruvector_graph/optimization/
query_jit.rs1use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9
10pub struct JitCompiler {
12 compiled_cache: Arc<RwLock<HashMap<String, Arc<JitQuery>>>>,
14 stats: Arc<RwLock<QueryStats>>,
16}
17
18impl JitCompiler {
19 pub fn new() -> Self {
20 Self {
21 compiled_cache: Arc::new(RwLock::new(HashMap::new())),
22 stats: Arc::new(RwLock::new(QueryStats::new())),
23 }
24 }
25
26 pub fn compile(&self, pattern: &str) -> Arc<JitQuery> {
28 {
30 let cache = self.compiled_cache.read();
31 if let Some(compiled) = cache.get(pattern) {
32 return Arc::clone(compiled);
33 }
34 }
35
36 let query = Arc::new(self.compile_pattern(pattern));
38
39 self.compiled_cache
41 .write()
42 .insert(pattern.to_string(), Arc::clone(&query));
43
44 query
45 }
46
47 fn compile_pattern(&self, pattern: &str) -> JitQuery {
49 let operators = self.parse_and_optimize(pattern);
51
52 JitQuery {
53 pattern: pattern.to_string(),
54 operators,
55 }
56 }
57
58 fn parse_and_optimize(&self, pattern: &str) -> Vec<QueryOperator> {
60 let mut operators = Vec::new();
61
62 if pattern.contains("MATCH") && pattern.contains("WHERE") {
64 operators.push(QueryOperator::LabelScan {
66 label: "Label".to_string(),
67 });
68 operators.push(QueryOperator::Filter {
69 predicate: FilterPredicate::Equality {
70 property: "prop".to_string(),
71 value: PropertyValue::String("value".to_string()),
72 },
73 });
74 } else if pattern.contains("MATCH") && pattern.contains("->") {
75 operators.push(QueryOperator::Expand {
77 direction: Direction::Outgoing,
78 edge_label: None,
79 });
80 } else {
81 operators.push(QueryOperator::FullScan);
83 }
84
85 operators
86 }
87
88 pub fn record_execution(&self, pattern: &str, duration_ns: u64) {
90 self.stats.write().record(pattern, duration_ns);
91 }
92
93 pub fn get_hot_queries(&self, threshold: u64) -> Vec<String> {
95 self.stats.read().get_hot_queries(threshold)
96 }
97}
98
99impl Default for JitCompiler {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105pub struct JitQuery {
107 pub pattern: String,
109 pub operators: Vec<QueryOperator>,
111}
112
113impl JitQuery {
114 pub fn execute<F>(&self, mut executor: F) -> QueryResult
116 where
117 F: FnMut(&QueryOperator) -> IntermediateResult,
118 {
119 let mut result = IntermediateResult::default();
120
121 for operator in &self.operators {
122 result = executor(operator);
123 }
124
125 QueryResult {
126 nodes: result.nodes,
127 edges: result.edges,
128 }
129 }
130}
131
132#[derive(Debug, Clone)]
134pub enum QueryOperator {
135 FullScan,
137
138 LabelScan { label: String },
140
141 PropertyScan {
143 property: String,
144 value: PropertyValue,
145 },
146
147 Expand {
149 direction: Direction,
150 edge_label: Option<String>,
151 },
152
153 Filter { predicate: FilterPredicate },
155
156 Project { properties: Vec<String> },
158
159 Aggregate { function: AggregateFunction },
161
162 Sort { property: String, ascending: bool },
164
165 Limit { count: usize },
167}
168
/// Edge direction for an `Expand` step; its interpretation is left to the executor.
///
/// Being fieldless, it is `Copy` and hashable, so it can be passed by value and used as a
/// map/set key without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Direction {
    Incoming,
    Outgoing,
    Both,
}
175
176#[derive(Debug, Clone)]
177pub enum FilterPredicate {
178 Equality {
179 property: String,
180 value: PropertyValue,
181 },
182 Range {
183 property: String,
184 min: PropertyValue,
185 max: PropertyValue,
186 },
187 Regex {
188 property: String,
189 pattern: String,
190 },
191}
192
/// A scalar property value used in scans and predicates.
///
/// Implements `PartialEq` but not `Eq`: the `Float` variant compares with IEEE-754 semantics,
/// so `Float(f64::NAN) != Float(f64::NAN)`. Values of different variants never compare equal
/// (e.g. `Integer(1) != Float(1.0)`).
#[derive(Debug, Clone, PartialEq)]
pub enum PropertyValue {
    String(String),
    Integer(i64),
    Float(f64),
    Boolean(bool),
}
200
/// Aggregate carried by a `QueryOperator::Aggregate` step.
///
/// Every variant except `Count` names the property it aggregates over.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggregateFunction {
    Count,
    Sum { property: String },
    Avg { property: String },
    Min { property: String },
    Max { property: String },
}
209
/// Value an executor returns for a single operator (see `JitQuery::execute`).
///
/// The default value has empty `nodes` and `edges`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct IntermediateResult {
    /// Node ids produced by the step.
    pub nodes: Vec<u64>,
    /// Edges produced by the step, as pairs of node ids.
    pub edges: Vec<(u64, u64)>,
}
216
/// Final output of `JitQuery::execute`: the result of the pipeline's last operator.
///
/// The default value has empty `nodes` and `edges`, matching `IntermediateResult`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct QueryResult {
    /// Node ids in the final result.
    pub nodes: Vec<u64>,
    /// Edges in the final result, as pairs of node ids.
    pub edges: Vec<(u64, u64)>,
}
222
/// Per-pattern execution statistics fed by `JitCompiler::record_execution`.
#[derive(Debug, Default)]
struct QueryStats {
    /// Counters for every pattern recorded so far, keyed by the exact pattern string.
    ///
    /// A single map (rather than one map per counter) means recording an execution costs one
    /// hash lookup and allocates a key only the first time a pattern is seen.
    per_pattern: HashMap<String, PatternStats>,
}

/// Counters accumulated for one query pattern.
#[derive(Debug, Default)]
struct PatternStats {
    /// Number of recorded executions (saturating at `u64::MAX`).
    count: u64,
    /// Sum of recorded durations in nanoseconds (saturating at `u64::MAX`).
    total_time_ns: u64,
}

impl PatternStats {
    /// Folds one execution of `duration_ns` nanoseconds into the counters.
    ///
    /// Saturating arithmetic avoids a debug-build overflow panic, or a silent wrap in release
    /// builds, when very large totals accumulate.
    fn add(&mut self, duration_ns: u64) {
        self.count = self.count.saturating_add(1);
        self.total_time_ns = self.total_time_ns.saturating_add(duration_ns);
    }
}

impl QueryStats {
    /// Creates an empty statistics table.
    fn new() -> Self {
        Self::default()
    }

    /// Records one execution of `pattern` that took `duration_ns` nanoseconds.
    fn record(&mut self, pattern: &str, duration_ns: u64) {
        // Look up by `&str` first so repeat executions of a known pattern don't allocate.
        if let Some(stats) = self.per_pattern.get_mut(pattern) {
            stats.add(duration_ns);
        } else {
            let mut stats = PatternStats::default();
            stats.add(duration_ns);
            self.per_pattern.insert(pattern.to_owned(), stats);
        }
    }

    /// Returns every pattern recorded at least `threshold` times.
    ///
    /// Results are ordered by execution count, highest first, with ties broken by pattern
    /// text. The output is therefore deterministic despite `HashMap`'s unspecified iteration
    /// order.
    fn get_hot_queries(&self, threshold: u64) -> Vec<String> {
        let mut hot: Vec<(&String, u64)> = self
            .per_pattern
            .iter()
            .filter(|(_, stats)| stats.count >= threshold)
            .map(|(pattern, stats)| (pattern, stats.count))
            .collect();
        hot.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(b.0)));
        // Clone only the surviving keys.
        hot.into_iter().map(|(pattern, _)| pattern.clone()).collect()
    }

    /// Returns the mean recorded duration for `pattern` in nanoseconds (integer division,
    /// truncated), or `None` if the pattern has never been recorded.
    #[allow(dead_code)] // not called from the rest of this file yet; kept for planner use
    fn avg_time_ns(&self, pattern: &str) -> Option<u64> {
        let stats = self.per_pattern.get(pattern)?;
        stats.total_time_ns.checked_div(stats.count)
    }
}
266
267pub mod specialized_ops {
269 use super::*;
270
271 pub fn vectorized_label_scan(label: &str, nodes: &[u64]) -> Vec<u64> {
273 nodes.iter().copied().collect()
275 }
276
277 pub fn vectorized_property_filter(
279 property: &str,
280 predicate: &FilterPredicate,
281 nodes: &[u64],
282 ) -> Vec<u64> {
283 nodes.iter().copied().collect()
285 }
286
287 pub fn cache_friendly_expand(nodes: &[u64], direction: Direction) -> Vec<(u64, u64)> {
289 Vec::new()
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_jit_compiler() {
        let compiler = JitCompiler::new();

        let query = compiler.compile("MATCH (n:Person) WHERE n.age > 18");
        assert!(!query.operators.is_empty());
        assert_eq!(query.pattern, "MATCH (n:Person) WHERE n.age > 18");
        assert!(matches!(query.operators[0], QueryOperator::LabelScan { .. }));
        assert!(matches!(query.operators[1], QueryOperator::Filter { .. }));
    }

    #[test]
    fn test_compile_is_cached() {
        let compiler = JitCompiler::new();

        let first = compiler.compile("MATCH (n)");
        let second = compiler.compile("MATCH (n)");
        // A cache hit must hand back the very same compiled plan, not a recompiled copy.
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_compile_expand_and_fallback() {
        let compiler = JitCompiler::new();

        let expand = compiler.compile("MATCH (a)-[:KNOWS]->(b)");
        assert!(matches!(
            expand.operators.as_slice(),
            [QueryOperator::Expand { .. }]
        ));

        let fallback = compiler.compile("RETURN 1");
        assert!(matches!(
            fallback.operators.as_slice(),
            [QueryOperator::FullScan]
        ));
    }

    #[test]
    fn test_execute_returns_last_operator_result() {
        let query = JitQuery {
            pattern: "test".to_string(),
            operators: vec![QueryOperator::FullScan, QueryOperator::Limit { count: 1 }],
        };

        let mut visited = Vec::new();
        let result = query.execute(|op| {
            visited.push(format!("{:?}", op));
            IntermediateResult {
                nodes: vec![visited.len() as u64],
                edges: vec![],
            }
        });

        assert_eq!(
            visited,
            vec!["FullScan".to_string(), "Limit { count: 1 }".to_string()]
        );
        assert_eq!(result.nodes, vec![2]);
        assert!(result.edges.is_empty());

        let empty = JitQuery {
            pattern: String::new(),
            operators: Vec::new(),
        };
        let result = empty.execute(|_| unreachable!("empty pipeline must not call executor"));
        assert!(result.nodes.is_empty());
        assert!(result.edges.is_empty());
    }

    #[test]
    fn test_query_stats() {
        let compiler = JitCompiler::new();

        compiler.record_execution("MATCH (n)", 1000);
        compiler.record_execution("MATCH (n)", 2000);
        compiler.record_execution("MATCH (n)", 3000);

        let hot = compiler.get_hot_queries(2);
        assert_eq!(hot.len(), 1);
        assert_eq!(hot[0], "MATCH (n)");
        assert!(compiler.get_hot_queries(4).is_empty());
    }

    #[test]
    fn test_operator_chain() {
        let operators = vec![
            QueryOperator::LabelScan {
                label: "Person".to_string(),
            },
            QueryOperator::Filter {
                predicate: FilterPredicate::Range {
                    property: "age".to_string(),
                    min: PropertyValue::Integer(18),
                    max: PropertyValue::Integer(65),
                },
            },
            QueryOperator::Limit { count: 10 },
        ];

        assert_eq!(operators.len(), 3);

        // Cloning must preserve every operator and its payload.
        let cloned = operators.clone();
        assert!(matches!(
            cloned.as_slice(),
            [
                QueryOperator::LabelScan { label },
                QueryOperator::Filter {
                    predicate: FilterPredicate::Range { property, .. },
                },
                QueryOperator::Limit { count: 10 },
            ] if label == "Person" && property == "age"
        ));
    }
}