sql_cli/sql/functions/
comparison.rs

1use anyhow::{anyhow, Result};
2use std::cmp::Ordering;
3
4use super::{ArgCount, FunctionCategory, FunctionSignature, SqlFunction};
5use crate::data::datatable::DataValue;
6
7/// Helper to compare two `DataValues`
8/// Returns None if values are incomparable (different types that can't be coerced)
9fn compare_values(a: &DataValue, b: &DataValue) -> Option<Ordering> {
10    match (a, b) {
11        // Null handling - NULL is considered smallest
12        (DataValue::Null, DataValue::Null) => Some(Ordering::Equal),
13        (DataValue::Null, _) => Some(Ordering::Less),
14        (_, DataValue::Null) => Some(Ordering::Greater),
15
16        // Same type comparisons
17        (DataValue::Integer(x), DataValue::Integer(y)) => Some(x.cmp(y)),
18        (DataValue::Float(x), DataValue::Float(y)) => {
19            // Handle NaN values
20            if x.is_nan() && y.is_nan() {
21                Some(Ordering::Equal)
22            } else if x.is_nan() {
23                Some(Ordering::Less) // NaN is treated as smallest
24            } else if y.is_nan() {
25                Some(Ordering::Greater)
26            } else {
27                x.partial_cmp(y)
28            }
29        }
30        (DataValue::String(x), DataValue::String(y)) => Some(x.cmp(y)),
31        (DataValue::InternedString(x), DataValue::InternedString(y)) => Some(x.cmp(y)),
32        (DataValue::String(x), DataValue::InternedString(y))
33        | (DataValue::InternedString(y), DataValue::String(x)) => Some(x.as_str().cmp(y.as_str())),
34        (DataValue::Boolean(x), DataValue::Boolean(y)) => Some(x.cmp(y)),
35        (DataValue::DateTime(x), DataValue::DateTime(y)) => Some(x.cmp(y)),
36
37        // Numeric coercion - allow comparing integers and floats
38        (DataValue::Integer(x), DataValue::Float(y)) => {
39            let x_float = *x as f64;
40            if y.is_nan() {
41                Some(Ordering::Greater)
42            } else {
43                x_float.partial_cmp(y)
44            }
45        }
46        (DataValue::Float(x), DataValue::Integer(y)) => {
47            let y_float = *y as f64;
48            if x.is_nan() {
49                Some(Ordering::Less)
50            } else {
51                x.partial_cmp(&y_float)
52            }
53        }
54
55        // Different types that can't be compared
56        _ => None,
57    }
58}
59
60/// GREATEST function - returns the maximum value from a list of values
61pub struct GreatestFunction;
62
63impl SqlFunction for GreatestFunction {
64    fn signature(&self) -> FunctionSignature {
65        FunctionSignature {
66            name: "GREATEST",
67            category: FunctionCategory::Mathematical,
68            arg_count: ArgCount::Variadic,
69            description: "Returns the greatest value from a list of values",
70            returns: "ANY",
71            examples: vec![
72                "SELECT GREATEST(10, 20, 5)",
73                "SELECT GREATEST(salary, bonus, commission) as max_pay FROM employees",
74                "SELECT GREATEST('apple', 'banana', 'cherry')",
75                "SELECT GREATEST(date1, date2, date3) as latest_date",
76            ],
77        }
78    }
79
80    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
81        if args.is_empty() {
82            return Err(anyhow!("GREATEST requires at least one argument"));
83        }
84
85        // Start with the first non-null value, or return NULL if all are NULL
86        let mut greatest = None;
87
88        for arg in args {
89            match &greatest {
90                None => {
91                    // First value or all previous were NULL
92                    if !matches!(arg, DataValue::Null) {
93                        greatest = Some(arg.clone());
94                    }
95                }
96                Some(current) => {
97                    // Skip NULL values when we already have a non-null value
98                    if matches!(arg, DataValue::Null) {
99                        continue;
100                    }
101                    // Compare with current greatest
102                    match compare_values(arg, current) {
103                        Some(Ordering::Greater) => {
104                            greatest = Some(arg.clone());
105                        }
106                        Some(_) => {
107                            // Keep current greatest
108                        }
109                        None => {
110                            // Type mismatch - can't compare
111                            return Err(anyhow!(
112                                "GREATEST: Cannot compare values of different types: {:?} and {:?}",
113                                current,
114                                arg
115                            ));
116                        }
117                    }
118                }
119            }
120        }
121
122        // If all values were NULL, return NULL
123        Ok(greatest.unwrap_or(DataValue::Null))
124    }
125}
126
127/// LEAST function - returns the minimum value from a list of values
128pub struct LeastFunction;
129
130impl SqlFunction for LeastFunction {
131    fn signature(&self) -> FunctionSignature {
132        FunctionSignature {
133            name: "LEAST",
134            category: FunctionCategory::Mathematical,
135            arg_count: ArgCount::Variadic,
136            description: "Returns the smallest value from a list of values",
137            returns: "ANY",
138            examples: vec![
139                "SELECT LEAST(10, 20, 5)",
140                "SELECT LEAST(salary, min_wage) as lower_bound FROM employees",
141                "SELECT LEAST('apple', 'banana', 'cherry')",
142                "SELECT LEAST(date1, date2, date3) as earliest_date",
143            ],
144        }
145    }
146
147    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
148        if args.is_empty() {
149            return Err(anyhow!("LEAST requires at least one argument"));
150        }
151
152        // Start with the first non-null value, or return NULL if all are NULL
153        let mut least = None;
154
155        for arg in args {
156            match &least {
157                None => {
158                    // First value or all previous were NULL
159                    if !matches!(arg, DataValue::Null) {
160                        least = Some(arg.clone());
161                    }
162                }
163                Some(current) => {
164                    // Skip NULL values when we already have a non-null value
165                    if matches!(arg, DataValue::Null) {
166                        continue;
167                    }
168                    // Compare with current least
169                    match compare_values(arg, current) {
170                        Some(Ordering::Less) => {
171                            least = Some(arg.clone());
172                        }
173                        Some(_) => {
174                            // Keep current least
175                        }
176                        None => {
177                            // Type mismatch - can't compare
178                            return Err(anyhow!(
179                                "LEAST: Cannot compare values of different types: {:?} and {:?}",
180                                current,
181                                arg
182                            ));
183                        }
184                    }
185                }
186            }
187        }
188
189        // If all values were NULL, return NULL
190        Ok(least.unwrap_or(DataValue::Null))
191    }
192}
193
194/// COALESCE function - returns the first non-null value
195pub struct CoalesceFunction;
196
197impl SqlFunction for CoalesceFunction {
198    fn signature(&self) -> FunctionSignature {
199        FunctionSignature {
200            name: "COALESCE",
201            category: FunctionCategory::Mathematical,
202            arg_count: ArgCount::Variadic,
203            description: "Returns the first non-null value from a list",
204            returns: "ANY",
205            examples: vec![
206                "SELECT COALESCE(NULL, 'default', 'backup')",
207                "SELECT COALESCE(phone, mobile, email) as contact FROM users",
208                "SELECT COALESCE(discount, 0) as discount_amount",
209            ],
210        }
211    }
212
213    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
214        if args.is_empty() {
215            return Err(anyhow!("COALESCE requires at least one argument"));
216        }
217
218        // Return the first non-null value
219        for arg in args {
220            if !matches!(arg, DataValue::Null) {
221                return Ok(arg.clone());
222            }
223        }
224
225        // All values were NULL
226        Ok(DataValue::Null)
227    }
228}
229
230/// NULLIF function - returns NULL if two values are equal
231pub struct NullIfFunction;
232
233impl SqlFunction for NullIfFunction {
234    fn signature(&self) -> FunctionSignature {
235        FunctionSignature {
236            name: "NULLIF",
237            category: FunctionCategory::Mathematical,
238            arg_count: ArgCount::Fixed(2),
239            description: "Returns NULL if two values are equal, otherwise returns the first value",
240            returns: "ANY",
241            examples: vec![
242                "SELECT NULLIF(0, 0)", // Returns NULL
243                "SELECT NULLIF(price, 0) as non_zero_price",
244                "SELECT NULLIF(status, 'DELETED') as active_status",
245            ],
246        }
247    }
248
249    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
250        self.validate_args(args)?;
251
252        let val1 = &args[0];
253        let val2 = &args[1];
254
255        // Check if values are equal
256        match compare_values(val1, val2) {
257            Some(Ordering::Equal) => Ok(DataValue::Null),
258            Some(_) => Ok(val1.clone()),
259            None => {
260                // Different types - they can't be equal
261                Ok(val1.clone())
262            }
263        }
264    }
265}
266
267/// IIF function - immediate if (if-then-else)
268pub struct IifFunction;
269
270impl SqlFunction for IifFunction {
271    fn signature(&self) -> FunctionSignature {
272        FunctionSignature {
273            name: "IIF",
274            category: FunctionCategory::Mathematical,
275            arg_count: ArgCount::Fixed(3),
276            description: "Returns second argument if first is true, third if false",
277            returns: "ANY",
278            examples: vec![
279                "SELECT IIF(1 > 0, 'positive', 'negative')",
280                "SELECT IIF(MASS_SUN() > MASS_EARTH(), 'sun', 'earth') as bigger",
281                "SELECT IIF(price > 100, 'expensive', 'affordable') as price_category",
282            ],
283        }
284    }
285
286    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
287        self.validate_args(args)?;
288
289        let condition = &args[0];
290        let true_value = &args[1];
291        let false_value = &args[2];
292
293        // Evaluate condition as boolean
294        let is_true = match condition {
295            DataValue::Boolean(b) => *b,
296            DataValue::Integer(i) => *i != 0,
297            DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
298            DataValue::String(s) => !s.is_empty(),
299            DataValue::InternedString(s) => !s.is_empty(),
300            DataValue::Null => false,
301            _ => false,
302        };
303
304        Ok(if is_true {
305            true_value.clone()
306        } else {
307            false_value.clone()
308        })
309    }
310}
311
312/// `GREATEST_LABEL` function - returns the label associated with the greatest value
313/// Takes pairs of (label, value) and returns the label of the maximum value
314pub struct GreatestLabelFunction;
315
316impl SqlFunction for GreatestLabelFunction {
317    fn signature(&self) -> FunctionSignature {
318        FunctionSignature {
319            name: "GREATEST_LABEL",
320            category: FunctionCategory::Mathematical,
321            arg_count: ArgCount::Variadic,
322            description:
323                "Returns the label associated with the greatest value from label/value pairs",
324            returns: "STRING",
325            examples: vec![
326                "SELECT GREATEST_LABEL('earth', MASS_EARTH(), 'sun', MASS_SUN()) as bigger_body",
327                "SELECT GREATEST_LABEL('jan', 100, 'feb', 150, 'mar', 120) as best_month",
328                "SELECT GREATEST_LABEL('product_a', sales_a, 'product_b', sales_b) as top_product",
329            ],
330        }
331    }
332
333    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
334        if args.is_empty() {
335            return Err(anyhow!(
336                "GREATEST_LABEL requires at least one label/value pair"
337            ));
338        }
339
340        if args.len() % 2 != 0 {
341            return Err(anyhow!(
342                "GREATEST_LABEL requires an even number of arguments (label/value pairs)"
343            ));
344        }
345
346        let mut best_label = None;
347        let mut best_value = None;
348
349        // Process pairs of (label, value)
350        for i in (0..args.len()).step_by(2) {
351            let label = &args[i];
352            let value = &args[i + 1];
353
354            // Skip if value is NULL
355            if matches!(value, DataValue::Null) {
356                continue;
357            }
358
359            match &best_value {
360                None => {
361                    // First non-null value
362                    best_label = Some(label.clone());
363                    best_value = Some(value.clone());
364                }
365                Some(current_best) => {
366                    // Compare with current best
367                    match compare_values(value, current_best) {
368                        Some(Ordering::Greater) => {
369                            best_label = Some(label.clone());
370                            best_value = Some(value.clone());
371                        }
372                        Some(_) => {
373                            // Keep current best
374                        }
375                        None => {
376                            // Type mismatch - can't compare
377                            return Err(anyhow!(
378                                "GREATEST_LABEL: Cannot compare values of different types: {:?} and {:?}",
379                                current_best,
380                                value
381                            ));
382                        }
383                    }
384                }
385            }
386        }
387
388        // Return the label of the greatest value, or NULL if all values were NULL
389        Ok(best_label.unwrap_or(DataValue::Null))
390    }
391}
392
393/// `LEAST_LABEL` function - returns the label associated with the smallest value
394pub struct LeastLabelFunction;
395
396impl SqlFunction for LeastLabelFunction {
397    fn signature(&self) -> FunctionSignature {
398        FunctionSignature {
399            name: "LEAST_LABEL",
400            category: FunctionCategory::Mathematical,
401            arg_count: ArgCount::Variadic,
402            description: "Returns the label associated with the smallest value from label/value pairs",
403            returns: "STRING",
404            examples: vec![
405                "SELECT LEAST_LABEL('mercury', MASS_MERCURY(), 'earth', MASS_EARTH()) as smaller_planet",
406                "SELECT LEAST_LABEL('jan', 100, 'feb', 150, 'mar', 120) as worst_month",
407                "SELECT LEAST_LABEL('cost_a', 50, 'cost_b', 30) as cheapest_option",
408            ],
409        }
410    }
411
412    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
413        if args.is_empty() {
414            return Err(anyhow!(
415                "LEAST_LABEL requires at least one label/value pair"
416            ));
417        }
418
419        if args.len() % 2 != 0 {
420            return Err(anyhow!(
421                "LEAST_LABEL requires an even number of arguments (label/value pairs)"
422            ));
423        }
424
425        let mut best_label = None;
426        let mut best_value = None;
427
428        // Process pairs of (label, value)
429        for i in (0..args.len()).step_by(2) {
430            let label = &args[i];
431            let value = &args[i + 1];
432
433            // Skip if value is NULL
434            if matches!(value, DataValue::Null) {
435                continue;
436            }
437
438            match &best_value {
439                None => {
440                    // First non-null value
441                    best_label = Some(label.clone());
442                    best_value = Some(value.clone());
443                }
444                Some(current_best) => {
445                    // Compare with current best
446                    match compare_values(value, current_best) {
447                        Some(Ordering::Less) => {
448                            best_label = Some(label.clone());
449                            best_value = Some(value.clone());
450                        }
451                        Some(_) => {
452                            // Keep current best
453                        }
454                        None => {
455                            // Type mismatch - can't compare
456                            return Err(anyhow!(
457                                "LEAST_LABEL: Cannot compare values of different types: {:?} and {:?}",
458                                current_best,
459                                value
460                            ));
461                        }
462                    }
463                }
464            }
465        }
466
467        // Return the label of the smallest value, or NULL if all values were NULL
468        Ok(best_label.unwrap_or(DataValue::Null))
469    }
470}
471
472/// Register all comparison functions
473pub fn register_comparison_functions(registry: &mut super::FunctionRegistry) {
474    registry.register(Box::new(GreatestFunction));
475    registry.register(Box::new(LeastFunction));
476    registry.register(Box::new(CoalesceFunction));
477    registry.register(Box::new(NullIfFunction));
478    registry.register(Box::new(IifFunction));
479    registry.register(Box::new(GreatestLabelFunction));
480    registry.register(Box::new(LeastLabelFunction));
481}
482
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps each integer in `DataValue::Integer`.
    fn ints(values: &[i64]) -> Vec<DataValue> {
        values.iter().copied().map(DataValue::Integer).collect()
    }

    /// Wraps each float in `DataValue::Float`.
    fn floats(values: &[f64]) -> Vec<DataValue> {
        values.iter().copied().map(DataValue::Float).collect()
    }

    /// Builds an owned `DataValue::String`.
    fn text(value: &str) -> DataValue {
        DataValue::String(value.to_string())
    }

    #[test]
    fn test_greatest_with_integers() {
        let result = GreatestFunction.evaluate(&ints(&[10, 5, 20, 15])).unwrap();
        assert_eq!(result, DataValue::Integer(20));
    }

    #[test]
    fn test_greatest_with_floats() {
        let result = GreatestFunction
            .evaluate(&floats(&[10.5, 5.2, 20.8, 15.3]))
            .unwrap();
        assert_eq!(result, DataValue::Float(20.8));
    }

    #[test]
    fn test_greatest_with_mixed_numbers() {
        let args = [
            DataValue::Integer(10),
            DataValue::Float(15.5),
            DataValue::Integer(20),
            DataValue::Float(12.3),
        ];

        let result = GreatestFunction.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Integer(20));
    }

    #[test]
    fn test_greatest_with_nulls() {
        let args = [
            DataValue::Null,
            DataValue::Integer(10),
            DataValue::Null,
            DataValue::Integer(5),
        ];

        let result = GreatestFunction.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_greatest_all_nulls() {
        let args = [DataValue::Null, DataValue::Null, DataValue::Null];

        let result = GreatestFunction.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Null);
    }

    #[test]
    fn test_greatest_with_strings() {
        let args = [text("apple"), text("banana"), text("cherry")];

        let result = GreatestFunction.evaluate(&args).unwrap();
        assert_eq!(result, text("cherry"));
    }

    #[test]
    fn test_least_with_integers() {
        let result = LeastFunction.evaluate(&ints(&[10, 5, 20, 15])).unwrap();
        assert_eq!(result, DataValue::Integer(5));
    }

    #[test]
    fn test_least_with_nulls() {
        let args = [
            DataValue::Integer(10),
            DataValue::Null,
            DataValue::Integer(5),
            DataValue::Integer(20),
        ];

        let result = LeastFunction.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Integer(5));
    }

    #[test]
    fn test_coalesce() {
        let args = [
            DataValue::Null,
            DataValue::Null,
            text("first"),
            text("second"),
        ];

        let result = CoalesceFunction.evaluate(&args).unwrap();
        assert_eq!(result, text("first"));
    }

    #[test]
    fn test_nullif_equal() {
        let result = NullIfFunction.evaluate(&ints(&[5, 5])).unwrap();
        assert_eq!(result, DataValue::Null);
    }

    #[test]
    fn test_nullif_not_equal() {
        let result = NullIfFunction.evaluate(&ints(&[5, 10])).unwrap();
        assert_eq!(result, DataValue::Integer(5));
    }
}