1use crate::executor::pipeline::RowBatch;
6use crate::executor::plan::{Predicate, Value};
7use crate::executor::{ExecutionError, Result};
8use std::collections::HashMap;
9
10#[cfg(target_arch = "x86_64")]
11use std::arch::x86_64::*;
12
13pub trait Operator: Send + Sync {
15 fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>>;
17
18 fn name(&self) -> &str;
20
21 fn is_pipeline_breaker(&self) -> bool {
23 false
24 }
25}
26
27#[derive(Debug, Clone)]
29pub enum ScanMode {
30 Sequential,
32 Index { index_name: String },
34 Range { start: Value, end: Value },
36}
37
38pub struct NodeScan {
40 mode: ScanMode,
41 filter: Option<Predicate>,
42 position: usize,
43}
44
45impl NodeScan {
46 pub fn new(mode: ScanMode, filter: Option<Predicate>) -> Self {
47 Self {
48 mode,
49 filter,
50 position: 0,
51 }
52 }
53}
54
55impl Operator for NodeScan {
56 fn execute(&mut self, _input: Option<RowBatch>) -> Result<Option<RowBatch>> {
57 Ok(None)
60 }
61
62 fn name(&self) -> &str {
63 "NodeScan"
64 }
65}
66
67pub struct EdgeScan {
69 mode: ScanMode,
70 filter: Option<Predicate>,
71 position: usize,
72}
73
74impl EdgeScan {
75 pub fn new(mode: ScanMode, filter: Option<Predicate>) -> Self {
76 Self {
77 mode,
78 filter,
79 position: 0,
80 }
81 }
82}
83
84impl Operator for EdgeScan {
85 fn execute(&mut self, _input: Option<RowBatch>) -> Result<Option<RowBatch>> {
86 Ok(None)
87 }
88
89 fn name(&self) -> &str {
90 "EdgeScan"
91 }
92}
93
94pub struct HyperedgeScan {
96 mode: ScanMode,
97 filter: Option<Predicate>,
98}
99
100impl HyperedgeScan {
101 pub fn new(mode: ScanMode, filter: Option<Predicate>) -> Self {
102 Self { mode, filter }
103 }
104}
105
106impl Operator for HyperedgeScan {
107 fn execute(&mut self, _input: Option<RowBatch>) -> Result<Option<RowBatch>> {
108 Ok(None)
109 }
110
111 fn name(&self) -> &str {
112 "HyperedgeScan"
113 }
114}
115
116pub struct Filter {
118 predicate: Predicate,
119}
120
121impl Filter {
122 pub fn new(predicate: Predicate) -> Self {
123 Self { predicate }
124 }
125
126 fn evaluate(&self, row: &HashMap<String, Value>) -> bool {
128 self.evaluate_predicate(&self.predicate, row)
129 }
130
131 fn evaluate_predicate(&self, pred: &Predicate, row: &HashMap<String, Value>) -> bool {
132 match pred {
133 Predicate::Equals(col, val) => row.get(col).map(|v| v == val).unwrap_or(false),
134 Predicate::NotEquals(col, val) => row.get(col).map(|v| v != val).unwrap_or(false),
135 Predicate::GreaterThan(col, val) => row
136 .get(col)
137 .and_then(|v| v.compare(val))
138 .map(|ord| ord == std::cmp::Ordering::Greater)
139 .unwrap_or(false),
140 Predicate::GreaterThanOrEqual(col, val) => row
141 .get(col)
142 .and_then(|v| v.compare(val))
143 .map(|ord| ord != std::cmp::Ordering::Less)
144 .unwrap_or(false),
145 Predicate::LessThan(col, val) => row
146 .get(col)
147 .and_then(|v| v.compare(val))
148 .map(|ord| ord == std::cmp::Ordering::Less)
149 .unwrap_or(false),
150 Predicate::LessThanOrEqual(col, val) => row
151 .get(col)
152 .and_then(|v| v.compare(val))
153 .map(|ord| ord != std::cmp::Ordering::Greater)
154 .unwrap_or(false),
155 Predicate::In(col, values) => row.get(col).map(|v| values.contains(v)).unwrap_or(false),
156 Predicate::Like(col, pattern) => {
157 if let Some(Value::String(s)) = row.get(col) {
158 self.pattern_match(s, pattern)
159 } else {
160 false
161 }
162 }
163 Predicate::And(preds) => preds.iter().all(|p| self.evaluate_predicate(p, row)),
164 Predicate::Or(preds) => preds.iter().any(|p| self.evaluate_predicate(p, row)),
165 Predicate::Not(pred) => !self.evaluate_predicate(pred, row),
166 }
167 }
168
169 fn pattern_match(&self, s: &str, pattern: &str) -> bool {
170 if pattern.starts_with('%') && pattern.ends_with('%') {
172 let p = &pattern[1..pattern.len() - 1];
173 s.contains(p)
174 } else if pattern.starts_with('%') {
175 let p = &pattern[1..];
176 s.ends_with(p)
177 } else if pattern.ends_with('%') {
178 let p = &pattern[..pattern.len() - 1];
179 s.starts_with(p)
180 } else {
181 s == pattern
182 }
183 }
184
185 #[cfg(target_arch = "x86_64")]
187 fn filter_batch_simd(&self, values: &[f32], threshold: f32) -> Vec<bool> {
188 if is_x86_feature_detected!("avx2") {
189 unsafe { self.filter_batch_avx2(values, threshold) }
190 } else {
191 self.filter_batch_scalar(values, threshold)
192 }
193 }
194
195 #[cfg(target_arch = "x86_64")]
196 #[target_feature(enable = "avx2")]
197 unsafe fn filter_batch_avx2(&self, values: &[f32], threshold: f32) -> Vec<bool> {
198 let mut result = vec![false; values.len()];
199 let threshold_vec = _mm256_set1_ps(threshold);
200
201 let chunks = values.len() / 8;
202 for i in 0..chunks {
203 let idx = i * 8;
204 let vals = _mm256_loadu_ps(values.as_ptr().add(idx));
205 let cmp = _mm256_cmp_ps(vals, threshold_vec, _CMP_GT_OQ);
206
207 let mask: [f32; 8] = std::mem::transmute(cmp);
208 for j in 0..8 {
209 result[idx + j] = mask[j] != 0.0;
210 }
211 }
212
213 for i in (chunks * 8)..values.len() {
215 result[i] = values[i] > threshold;
216 }
217
218 result
219 }
220
221 #[cfg(not(target_arch = "x86_64"))]
222 fn filter_batch_simd(&self, values: &[f32], threshold: f32) -> Vec<bool> {
223 self.filter_batch_scalar(values, threshold)
224 }
225
226 fn filter_batch_scalar(&self, values: &[f32], threshold: f32) -> Vec<bool> {
227 values.iter().map(|&v| v > threshold).collect()
228 }
229}
230
231impl Operator for Filter {
232 fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
233 if let Some(batch) = input {
234 let filtered_rows: Vec<_> = batch
235 .rows
236 .into_iter()
237 .filter(|row| self.evaluate(row))
238 .collect();
239
240 Ok(Some(RowBatch {
241 rows: filtered_rows,
242 schema: batch.schema,
243 }))
244 } else {
245 Ok(None)
246 }
247 }
248
249 fn name(&self) -> &str {
250 "Filter"
251 }
252}
253
/// Kind of join to perform; variants are named after the standard SQL join types.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum JoinType {
    /// `INNER JOIN`.
    Inner,
    /// `LEFT OUTER JOIN`.
    LeftOuter,
    /// `RIGHT OUTER JOIN`.
    RightOuter,
    /// `FULL OUTER JOIN`.
    FullOuter,
}
262
263pub struct Join {
265 join_type: JoinType,
266 on: Vec<(String, String)>,
267 hash_table: HashMap<Vec<Value>, Vec<HashMap<String, Value>>>,
268 built: bool,
269}
270
271impl Join {
272 pub fn new(join_type: JoinType, on: Vec<(String, String)>) -> Self {
273 Self {
274 join_type,
275 on,
276 hash_table: HashMap::new(),
277 built: false,
278 }
279 }
280
281 fn build_hash_table(&mut self, build_side: RowBatch) {
282 for row in build_side.rows {
283 let key: Vec<Value> = self
284 .on
285 .iter()
286 .filter_map(|(_, right_col)| row.get(right_col).cloned())
287 .collect();
288
289 self.hash_table
290 .entry(key)
291 .or_insert_with(Vec::new)
292 .push(row);
293 }
294 self.built = true;
295 }
296
297 fn probe(&self, probe_row: &HashMap<String, Value>) -> Vec<HashMap<String, Value>> {
298 let key: Vec<Value> = self
299 .on
300 .iter()
301 .filter_map(|(left_col, _)| probe_row.get(left_col).cloned())
302 .collect();
303
304 if let Some(matches) = self.hash_table.get(&key) {
305 matches
306 .iter()
307 .map(|right_row| {
308 let mut joined = probe_row.clone();
309 joined.extend(right_row.clone());
310 joined
311 })
312 .collect()
313 } else {
314 Vec::new()
315 }
316 }
317}
318
319impl Operator for Join {
320 fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
321 Ok(None)
323 }
324
325 fn name(&self) -> &str {
326 "Join"
327 }
328
329 fn is_pipeline_breaker(&self) -> bool {
330 true }
332}
333
/// Aggregate functions supported by [`Aggregate`], named after their SQL counterparts.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AggregateFunction {
    /// `COUNT`.
    Count,
    /// `SUM`.
    Sum,
    /// `AVG`.
    Avg,
    /// `MIN`.
    Min,
    /// `MAX`.
    Max,
}
343
344pub struct Aggregate {
346 group_by: Vec<String>,
347 aggregates: Vec<(AggregateFunction, String)>,
348 state: HashMap<Vec<Value>, Vec<f64>>,
349}
350
351impl Aggregate {
352 pub fn new(group_by: Vec<String>, aggregates: Vec<(AggregateFunction, String)>) -> Self {
353 Self {
354 group_by,
355 aggregates,
356 state: HashMap::new(),
357 }
358 }
359}
360
361impl Operator for Aggregate {
362 fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
363 Ok(None)
364 }
365
366 fn name(&self) -> &str {
367 "Aggregate"
368 }
369
370 fn is_pipeline_breaker(&self) -> bool {
371 true
372 }
373}
374
375pub struct Project {
377 columns: Vec<String>,
378}
379
380impl Project {
381 pub fn new(columns: Vec<String>) -> Self {
382 Self { columns }
383 }
384}
385
386impl Operator for Project {
387 fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
388 if let Some(batch) = input {
389 let projected: Vec<_> = batch
390 .rows
391 .into_iter()
392 .map(|row| {
393 self.columns
394 .iter()
395 .filter_map(|col| row.get(col).map(|v| (col.clone(), v.clone())))
396 .collect()
397 })
398 .collect();
399
400 Ok(Some(RowBatch {
401 rows: projected,
402 schema: batch.schema,
403 }))
404 } else {
405 Ok(None)
406 }
407 }
408
409 fn name(&self) -> &str {
410 "Project"
411 }
412}
413
414pub struct Sort {
416 order_by: Vec<(String, crate::executor::plan::SortOrder)>,
417 buffer: Vec<HashMap<String, Value>>,
418}
419
420impl Sort {
421 pub fn new(order_by: Vec<(String, crate::executor::plan::SortOrder)>) -> Self {
422 Self {
423 order_by,
424 buffer: Vec::new(),
425 }
426 }
427}
428
429impl Operator for Sort {
430 fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
431 Ok(None)
432 }
433
434 fn name(&self) -> &str {
435 "Sort"
436 }
437
438 fn is_pipeline_breaker(&self) -> bool {
439 true
440 }
441}
442
443pub struct Limit {
445 limit: usize,
446 offset: usize,
447 current: usize,
448}
449
450impl Limit {
451 pub fn new(limit: usize, offset: usize) -> Self {
452 Self {
453 limit,
454 offset,
455 current: 0,
456 }
457 }
458}
459
460impl Operator for Limit {
461 fn execute(&mut self, input: Option<RowBatch>) -> Result<Option<RowBatch>> {
462 if let Some(batch) = input {
463 let start = self.offset.saturating_sub(self.current);
464 let end = start + self.limit;
465
466 let limited: Vec<_> = batch
467 .rows
468 .into_iter()
469 .skip(start)
470 .take(end - start)
471 .collect();
472
473 self.current += limited.len();
474
475 Ok(Some(RowBatch {
476 rows: limited,
477 schema: batch.schema,
478 }))
479 } else {
480 Ok(None)
481 }
482 }
483
484 fn name(&self) -> &str {
485 "Limit"
486 }
487}
488
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_filter_operator() {
        let filter = Filter::new(Predicate::Equals("id".to_string(), Value::Int64(42)));

        let mut row = HashMap::new();
        row.insert("id".to_string(), Value::Int64(42));
        assert!(filter.evaluate(&row));

        // A different value, or a missing column, must not match.
        row.insert("id".to_string(), Value::Int64(7));
        assert!(!filter.evaluate(&row));
        assert!(!filter.evaluate(&HashMap::new()));
    }

    #[test]
    fn test_pattern_matching() {
        let filter = Filter::new(Predicate::Like("name".to_string(), "%test%".to_string()));
        assert!(filter.pattern_match("this is a test", "%test%"));
        assert!(!filter.pattern_match("this is a tent", "%test%"));
        assert!(filter.pattern_match("this is a test", "this%"));
        assert!(filter.pattern_match("this is a test", "%test"));
        assert!(!filter.pattern_match("this is a test", "%this"));
        assert!(filter.pattern_match("abc", "abc"));
        assert!(!filter.pattern_match("abc", "abd"));
    }

    #[test]
    fn test_simd_filtering() {
        let filter = Filter::new(Predicate::GreaterThan(
            "value".to_string(),
            Value::Float64(5.0),
        ));
        let values = vec![1.0, 6.0, 3.0, 8.0, 4.0, 9.0, 2.0, 7.0];
        let result = filter.filter_batch_simd(&values, 5.0);
        assert_eq!(
            result,
            vec![false, true, false, true, false, true, false, true]
        );
    }

    #[test]
    fn test_simd_filtering_remainder_and_nan() {
        let filter = Filter::new(Predicate::GreaterThan(
            "value".to_string(),
            Value::Float64(5.0),
        ));
        // 11 values: one full 8-lane chunk plus a 3-element scalar tail. Values
        // equal to the threshold, and NaN, are not "greater".
        let values = vec![1.0, 6.0, 3.0, 8.0, 4.0, 9.0, 2.0, f32::NAN, 5.0, 5.5, f32::NAN];
        let expected: Vec<bool> = values.iter().map(|&v| v > 5.0).collect();
        assert_eq!(filter.filter_batch_simd(&values, 5.0), expected);
        assert_eq!(
            expected,
            vec![false, true, false, true, false, true, false, false, false, true, false]
        );
    }
}