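//! `sql_cli/sql/functions/comparison.rs` — comparison and NULL-handling SQL functions
//! (`GREATEST`, `LEAST`, `COALESCE`, `IFNULL`, `NULLIF`, `IIF`, `GREATEST_LABEL`, `LEAST_LABEL`).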

1use anyhow::{anyhow, Result};
2use std::cmp::Ordering;
3
4use super::{ArgCount, FunctionCategory, FunctionSignature, SqlFunction};
5use crate::data::datatable::DataValue;
6
7/// Helper to compare two `DataValues`
8/// Returns None if values are incomparable (different types that can't be coerced)
9fn compare_values(a: &DataValue, b: &DataValue) -> Option<Ordering> {
10    match (a, b) {
11        // Null handling - NULL is considered smallest
12        (DataValue::Null, DataValue::Null) => Some(Ordering::Equal),
13        (DataValue::Null, _) => Some(Ordering::Less),
14        (_, DataValue::Null) => Some(Ordering::Greater),
15
16        // Same type comparisons
17        (DataValue::Integer(x), DataValue::Integer(y)) => Some(x.cmp(y)),
18        (DataValue::Float(x), DataValue::Float(y)) => {
19            // Handle NaN values
20            if x.is_nan() && y.is_nan() {
21                Some(Ordering::Equal)
22            } else if x.is_nan() {
23                Some(Ordering::Less) // NaN is treated as smallest
24            } else if y.is_nan() {
25                Some(Ordering::Greater)
26            } else {
27                x.partial_cmp(y)
28            }
29        }
30        (DataValue::String(x), DataValue::String(y)) => Some(x.cmp(y)),
31        (DataValue::InternedString(x), DataValue::InternedString(y)) => Some(x.cmp(y)),
32        (DataValue::String(x), DataValue::InternedString(y))
33        | (DataValue::InternedString(y), DataValue::String(x)) => Some(x.as_str().cmp(y.as_str())),
34        (DataValue::Boolean(x), DataValue::Boolean(y)) => Some(x.cmp(y)),
35        (DataValue::DateTime(x), DataValue::DateTime(y)) => Some(x.cmp(y)),
36
37        // Numeric coercion - allow comparing integers and floats
38        (DataValue::Integer(x), DataValue::Float(y)) => {
39            let x_float = *x as f64;
40            if y.is_nan() {
41                Some(Ordering::Greater)
42            } else {
43                x_float.partial_cmp(y)
44            }
45        }
46        (DataValue::Float(x), DataValue::Integer(y)) => {
47            let y_float = *y as f64;
48            if x.is_nan() {
49                Some(Ordering::Less)
50            } else {
51                x.partial_cmp(&y_float)
52            }
53        }
54
55        // Different types that can't be compared
56        _ => None,
57    }
58}
59
60/// GREATEST function - returns the maximum value from a list of values
61pub struct GreatestFunction;
62
63impl SqlFunction for GreatestFunction {
64    fn signature(&self) -> FunctionSignature {
65        FunctionSignature {
66            name: "GREATEST",
67            category: FunctionCategory::Mathematical,
68            arg_count: ArgCount::Variadic,
69            description: "Returns the greatest value from a list of values",
70            returns: "ANY",
71            examples: vec![
72                "SELECT GREATEST(10, 20, 5)",
73                "SELECT GREATEST(salary, bonus, commission) as max_pay FROM employees",
74                "SELECT GREATEST('apple', 'banana', 'cherry')",
75                "SELECT GREATEST(date1, date2, date3) as latest_date",
76            ],
77        }
78    }
79
80    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
81        if args.is_empty() {
82            return Err(anyhow!("GREATEST requires at least one argument"));
83        }
84
85        // Start with the first non-null value, or return NULL if all are NULL
86        let mut greatest = None;
87
88        for arg in args {
89            match &greatest {
90                None => {
91                    // First value or all previous were NULL
92                    if !matches!(arg, DataValue::Null) {
93                        greatest = Some(arg.clone());
94                    }
95                }
96                Some(current) => {
97                    // Skip NULL values when we already have a non-null value
98                    if matches!(arg, DataValue::Null) {
99                        continue;
100                    }
101                    // Compare with current greatest
102                    match compare_values(arg, current) {
103                        Some(Ordering::Greater) => {
104                            greatest = Some(arg.clone());
105                        }
106                        Some(_) => {
107                            // Keep current greatest
108                        }
109                        None => {
110                            // Type mismatch - can't compare
111                            return Err(anyhow!(
112                                "GREATEST: Cannot compare values of different types: {:?} and {:?}",
113                                current,
114                                arg
115                            ));
116                        }
117                    }
118                }
119            }
120        }
121
122        // If all values were NULL, return NULL
123        Ok(greatest.unwrap_or(DataValue::Null))
124    }
125}
126
127/// LEAST function - returns the minimum value from a list of values
128pub struct LeastFunction;
129
130impl SqlFunction for LeastFunction {
131    fn signature(&self) -> FunctionSignature {
132        FunctionSignature {
133            name: "LEAST",
134            category: FunctionCategory::Mathematical,
135            arg_count: ArgCount::Variadic,
136            description: "Returns the smallest value from a list of values",
137            returns: "ANY",
138            examples: vec![
139                "SELECT LEAST(10, 20, 5)",
140                "SELECT LEAST(salary, min_wage) as lower_bound FROM employees",
141                "SELECT LEAST('apple', 'banana', 'cherry')",
142                "SELECT LEAST(date1, date2, date3) as earliest_date",
143            ],
144        }
145    }
146
147    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
148        if args.is_empty() {
149            return Err(anyhow!("LEAST requires at least one argument"));
150        }
151
152        // Start with the first non-null value, or return NULL if all are NULL
153        let mut least = None;
154
155        for arg in args {
156            match &least {
157                None => {
158                    // First value or all previous were NULL
159                    if !matches!(arg, DataValue::Null) {
160                        least = Some(arg.clone());
161                    }
162                }
163                Some(current) => {
164                    // Skip NULL values when we already have a non-null value
165                    if matches!(arg, DataValue::Null) {
166                        continue;
167                    }
168                    // Compare with current least
169                    match compare_values(arg, current) {
170                        Some(Ordering::Less) => {
171                            least = Some(arg.clone());
172                        }
173                        Some(_) => {
174                            // Keep current least
175                        }
176                        None => {
177                            // Type mismatch - can't compare
178                            return Err(anyhow!(
179                                "LEAST: Cannot compare values of different types: {:?} and {:?}",
180                                current,
181                                arg
182                            ));
183                        }
184                    }
185                }
186            }
187        }
188
189        // If all values were NULL, return NULL
190        Ok(least.unwrap_or(DataValue::Null))
191    }
192}
193
194/// COALESCE function - returns the first non-null value
195pub struct CoalesceFunction;
196
197impl SqlFunction for CoalesceFunction {
198    fn signature(&self) -> FunctionSignature {
199        FunctionSignature {
200            name: "COALESCE",
201            category: FunctionCategory::Mathematical,
202            arg_count: ArgCount::Variadic,
203            description: "Returns the first non-null value from a list",
204            returns: "ANY",
205            examples: vec![
206                "SELECT COALESCE(NULL, 'default', 'backup')",
207                "SELECT COALESCE(phone, mobile, email) as contact FROM users",
208                "SELECT COALESCE(discount, 0) as discount_amount",
209            ],
210        }
211    }
212
213    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
214        if args.is_empty() {
215            return Err(anyhow!("COALESCE requires at least one argument"));
216        }
217
218        // Return the first non-null value
219        for arg in args {
220            if !matches!(arg, DataValue::Null) {
221                return Ok(arg.clone());
222            }
223        }
224
225        // All values were NULL
226        Ok(DataValue::Null)
227    }
228}
229
230/// IFNULL function - MySQL alias for 2-argument COALESCE
231pub struct IfNullFunction;
232
233impl SqlFunction for IfNullFunction {
234    fn signature(&self) -> FunctionSignature {
235        FunctionSignature {
236            name: "IFNULL",
237            category: FunctionCategory::Mathematical,
238            arg_count: ArgCount::Fixed(2),
239            description: "Returns expr if not NULL, otherwise returns the default value",
240            returns: "ANY",
241            examples: vec![
242                "SELECT IFNULL(NULL, 'default')", // Returns 'default'
243                "SELECT IFNULL(discount, 0) FROM orders",
244                "SELECT IFNULL(phone, mobile) AS contact FROM users",
245            ],
246        }
247    }
248
249    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
250        self.validate_args(args)?;
251        if matches!(args[0], DataValue::Null) {
252            Ok(args[1].clone())
253        } else {
254            Ok(args[0].clone())
255        }
256    }
257}
258
259/// NULLIF function - returns NULL if two values are equal
260pub struct NullIfFunction;
261
262impl SqlFunction for NullIfFunction {
263    fn signature(&self) -> FunctionSignature {
264        FunctionSignature {
265            name: "NULLIF",
266            category: FunctionCategory::Mathematical,
267            arg_count: ArgCount::Fixed(2),
268            description: "Returns NULL if two values are equal, otherwise returns the first value",
269            returns: "ANY",
270            examples: vec![
271                "SELECT NULLIF(0, 0)", // Returns NULL
272                "SELECT NULLIF(price, 0) as non_zero_price",
273                "SELECT NULLIF(status, 'DELETED') as active_status",
274            ],
275        }
276    }
277
278    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
279        self.validate_args(args)?;
280
281        let val1 = &args[0];
282        let val2 = &args[1];
283
284        // Check if values are equal
285        match compare_values(val1, val2) {
286            Some(Ordering::Equal) => Ok(DataValue::Null),
287            Some(_) => Ok(val1.clone()),
288            None => {
289                // Different types - they can't be equal
290                Ok(val1.clone())
291            }
292        }
293    }
294}
295
296/// IIF function - immediate if (if-then-else)
297pub struct IifFunction;
298
299impl SqlFunction for IifFunction {
300    fn signature(&self) -> FunctionSignature {
301        FunctionSignature {
302            name: "IIF",
303            category: FunctionCategory::Mathematical,
304            arg_count: ArgCount::Fixed(3),
305            description: "Returns second argument if first is true, third if false",
306            returns: "ANY",
307            examples: vec![
308                "SELECT IIF(1 > 0, 'positive', 'negative')",
309                "SELECT IIF(MASS_SUN() > MASS_EARTH(), 'sun', 'earth') as bigger",
310                "SELECT IIF(price > 100, 'expensive', 'affordable') as price_category",
311            ],
312        }
313    }
314
315    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
316        self.validate_args(args)?;
317
318        let condition = &args[0];
319        let true_value = &args[1];
320        let false_value = &args[2];
321
322        // Evaluate condition as boolean
323        let is_true = match condition {
324            DataValue::Boolean(b) => *b,
325            DataValue::Integer(i) => *i != 0,
326            DataValue::Float(f) => *f != 0.0 && !f.is_nan(),
327            DataValue::String(s) => !s.is_empty(),
328            DataValue::InternedString(s) => !s.is_empty(),
329            DataValue::Null => false,
330            _ => false,
331        };
332
333        Ok(if is_true {
334            true_value.clone()
335        } else {
336            false_value.clone()
337        })
338    }
339}
340
341/// `GREATEST_LABEL` function - returns the label associated with the greatest value
342/// Takes pairs of (label, value) and returns the label of the maximum value
343pub struct GreatestLabelFunction;
344
345impl SqlFunction for GreatestLabelFunction {
346    fn signature(&self) -> FunctionSignature {
347        FunctionSignature {
348            name: "GREATEST_LABEL",
349            category: FunctionCategory::Mathematical,
350            arg_count: ArgCount::Variadic,
351            description:
352                "Returns the label associated with the greatest value from label/value pairs",
353            returns: "STRING",
354            examples: vec![
355                "SELECT GREATEST_LABEL('earth', MASS_EARTH(), 'sun', MASS_SUN()) as bigger_body",
356                "SELECT GREATEST_LABEL('jan', 100, 'feb', 150, 'mar', 120) as best_month",
357                "SELECT GREATEST_LABEL('product_a', sales_a, 'product_b', sales_b) as top_product",
358            ],
359        }
360    }
361
362    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
363        if args.is_empty() {
364            return Err(anyhow!(
365                "GREATEST_LABEL requires at least one label/value pair"
366            ));
367        }
368
369        if args.len() % 2 != 0 {
370            return Err(anyhow!(
371                "GREATEST_LABEL requires an even number of arguments (label/value pairs)"
372            ));
373        }
374
375        let mut best_label = None;
376        let mut best_value = None;
377
378        // Process pairs of (label, value)
379        for i in (0..args.len()).step_by(2) {
380            let label = &args[i];
381            let value = &args[i + 1];
382
383            // Skip if value is NULL
384            if matches!(value, DataValue::Null) {
385                continue;
386            }
387
388            match &best_value {
389                None => {
390                    // First non-null value
391                    best_label = Some(label.clone());
392                    best_value = Some(value.clone());
393                }
394                Some(current_best) => {
395                    // Compare with current best
396                    match compare_values(value, current_best) {
397                        Some(Ordering::Greater) => {
398                            best_label = Some(label.clone());
399                            best_value = Some(value.clone());
400                        }
401                        Some(_) => {
402                            // Keep current best
403                        }
404                        None => {
405                            // Type mismatch - can't compare
406                            return Err(anyhow!(
407                                "GREATEST_LABEL: Cannot compare values of different types: {:?} and {:?}",
408                                current_best,
409                                value
410                            ));
411                        }
412                    }
413                }
414            }
415        }
416
417        // Return the label of the greatest value, or NULL if all values were NULL
418        Ok(best_label.unwrap_or(DataValue::Null))
419    }
420}
421
422/// `LEAST_LABEL` function - returns the label associated with the smallest value
423pub struct LeastLabelFunction;
424
425impl SqlFunction for LeastLabelFunction {
426    fn signature(&self) -> FunctionSignature {
427        FunctionSignature {
428            name: "LEAST_LABEL",
429            category: FunctionCategory::Mathematical,
430            arg_count: ArgCount::Variadic,
431            description: "Returns the label associated with the smallest value from label/value pairs",
432            returns: "STRING",
433            examples: vec![
434                "SELECT LEAST_LABEL('mercury', MASS_MERCURY(), 'earth', MASS_EARTH()) as smaller_planet",
435                "SELECT LEAST_LABEL('jan', 100, 'feb', 150, 'mar', 120) as worst_month",
436                "SELECT LEAST_LABEL('cost_a', 50, 'cost_b', 30) as cheapest_option",
437            ],
438        }
439    }
440
441    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
442        if args.is_empty() {
443            return Err(anyhow!(
444                "LEAST_LABEL requires at least one label/value pair"
445            ));
446        }
447
448        if args.len() % 2 != 0 {
449            return Err(anyhow!(
450                "LEAST_LABEL requires an even number of arguments (label/value pairs)"
451            ));
452        }
453
454        let mut best_label = None;
455        let mut best_value = None;
456
457        // Process pairs of (label, value)
458        for i in (0..args.len()).step_by(2) {
459            let label = &args[i];
460            let value = &args[i + 1];
461
462            // Skip if value is NULL
463            if matches!(value, DataValue::Null) {
464                continue;
465            }
466
467            match &best_value {
468                None => {
469                    // First non-null value
470                    best_label = Some(label.clone());
471                    best_value = Some(value.clone());
472                }
473                Some(current_best) => {
474                    // Compare with current best
475                    match compare_values(value, current_best) {
476                        Some(Ordering::Less) => {
477                            best_label = Some(label.clone());
478                            best_value = Some(value.clone());
479                        }
480                        Some(_) => {
481                            // Keep current best
482                        }
483                        None => {
484                            // Type mismatch - can't compare
485                            return Err(anyhow!(
486                                "LEAST_LABEL: Cannot compare values of different types: {:?} and {:?}",
487                                current_best,
488                                value
489                            ));
490                        }
491                    }
492                }
493            }
494        }
495
496        // Return the label of the smallest value, or NULL if all values were NULL
497        Ok(best_label.unwrap_or(DataValue::Null))
498    }
499}
500
/// Register all comparison functions with the given registry.
///
/// Adds `GREATEST`, `LEAST`, `COALESCE`, `IFNULL`, `NULLIF`, `IIF`,
/// `GREATEST_LABEL` and `LEAST_LABEL`, each passed to `registry.register`
/// as a boxed function object.
pub fn register_comparison_functions(registry: &mut super::FunctionRegistry) {
    // Extremum selection
    registry.register(Box::new(GreatestFunction));
    registry.register(Box::new(LeastFunction));
    // NULL handling
    registry.register(Box::new(CoalesceFunction));
    registry.register(Box::new(IfNullFunction));
    registry.register(Box::new(NullIfFunction));
    // Conditional selection
    registry.register(Box::new(IifFunction));
    // Label of the extreme value among label/value pairs
    registry.register(Box::new(GreatestLabelFunction));
    registry.register(Box::new(LeastLabelFunction));
}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building an owned string value in test inputs.
    fn s(text: &str) -> DataValue {
        DataValue::String(text.to_string())
    }

    #[test]
    fn test_greatest_with_integers() {
        let func = GreatestFunction;
        let args = vec![
            DataValue::Integer(10),
            DataValue::Integer(5),
            DataValue::Integer(20),
            DataValue::Integer(15),
        ];

        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Integer(20));
    }

    #[test]
    fn test_greatest_with_floats() {
        let func = GreatestFunction;
        let args = vec![
            DataValue::Float(10.5),
            DataValue::Float(5.2),
            DataValue::Float(20.8),
            DataValue::Float(15.3),
        ];

        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Float(20.8));
    }

    #[test]
    fn test_greatest_with_mixed_numbers() {
        let func = GreatestFunction;
        let args = vec![
            DataValue::Integer(10),
            DataValue::Float(15.5),
            DataValue::Integer(20),
            DataValue::Float(12.3),
        ];

        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Integer(20));
    }

    #[test]
    fn test_greatest_with_nulls() {
        let func = GreatestFunction;
        let args = vec![
            DataValue::Null,
            DataValue::Integer(10),
            DataValue::Null,
            DataValue::Integer(5),
        ];

        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_greatest_all_nulls() {
        let func = GreatestFunction;
        let args = vec![DataValue::Null, DataValue::Null, DataValue::Null];

        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Null);
    }

    #[test]
    fn test_greatest_with_strings() {
        let func = GreatestFunction;
        let args = vec![
            DataValue::String("apple".to_string()),
            DataValue::String("banana".to_string()),
            DataValue::String("cherry".to_string()),
        ];

        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::String("cherry".to_string()));
    }

    #[test]
    fn test_greatest_nan_sorts_lowest() {
        let func = GreatestFunction;
        let args = vec![DataValue::Float(f64::NAN), DataValue::Float(1.5)];
        assert_eq!(func.evaluate(&args).unwrap(), DataValue::Float(1.5));
    }

    #[test]
    fn test_greatest_and_least_reject_mixed_types() {
        let args = vec![DataValue::Integer(1), s("a")];
        assert!(GreatestFunction.evaluate(&args).is_err());
        assert!(LeastFunction.evaluate(&args).is_err());
    }

    #[test]
    fn test_variadic_functions_reject_empty_args() {
        assert!(GreatestFunction.evaluate(&[]).is_err());
        assert!(LeastFunction.evaluate(&[]).is_err());
        assert!(CoalesceFunction.evaluate(&[]).is_err());
    }

    #[test]
    fn test_least_with_integers() {
        let func = LeastFunction;
        let args = vec![
            DataValue::Integer(10),
            DataValue::Integer(5),
            DataValue::Integer(20),
            DataValue::Integer(15),
        ];

        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Integer(5));
    }

    #[test]
    fn test_least_with_nulls() {
        let func = LeastFunction;
        let args = vec![
            DataValue::Integer(10),
            DataValue::Null,
            DataValue::Integer(5),
            DataValue::Integer(20),
        ];

        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Integer(5));
    }

    #[test]
    fn test_least_all_nulls() {
        let func = LeastFunction;
        let args = vec![DataValue::Null, DataValue::Null];
        assert_eq!(func.evaluate(&args).unwrap(), DataValue::Null);
    }

    #[test]
    fn test_coalesce() {
        let func = CoalesceFunction;
        let args = vec![
            DataValue::Null,
            DataValue::Null,
            DataValue::String("first".to_string()),
            DataValue::String("second".to_string()),
        ];

        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::String("first".to_string()));
    }

    #[test]
    fn test_coalesce_all_nulls() {
        let func = CoalesceFunction;
        let args = vec![DataValue::Null, DataValue::Null];
        assert_eq!(func.evaluate(&args).unwrap(), DataValue::Null);
    }

    #[test]
    fn test_nullif_equal() {
        let func = NullIfFunction;
        let args = vec![DataValue::Integer(5), DataValue::Integer(5)];

        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Null);
    }

    #[test]
    fn test_nullif_not_equal() {
        let func = NullIfFunction;
        let args = vec![DataValue::Integer(5), DataValue::Integer(10)];

        let result = func.evaluate(&args).unwrap();
        assert_eq!(result, DataValue::Integer(5));
    }

    #[test]
    fn test_nullif_int_and_float_with_same_value_are_equal() {
        let func = NullIfFunction;
        let args = vec![DataValue::Integer(3), DataValue::Float(3.0)];
        assert_eq!(func.evaluate(&args).unwrap(), DataValue::Null);
    }

    #[test]
    fn test_nullif_different_types_returns_first() {
        let func = NullIfFunction;
        let args = vec![DataValue::Integer(5), s("5")];
        assert_eq!(func.evaluate(&args).unwrap(), DataValue::Integer(5));
    }

    #[test]
    fn test_ifnull_null_returns_default() {
        let func = IfNullFunction;
        let args = vec![DataValue::Null, DataValue::Integer(42)];
        assert_eq!(func.evaluate(&args).unwrap(), DataValue::Integer(42));
    }

    #[test]
    fn test_ifnull_non_null_returns_value() {
        let func = IfNullFunction;
        let args = vec![DataValue::Integer(7), DataValue::Integer(42)];
        assert_eq!(func.evaluate(&args).unwrap(), DataValue::Integer(7));
    }

    #[test]
    fn test_ifnull_requires_two_args() {
        let func = IfNullFunction;
        assert!(func.evaluate(&[DataValue::Null]).is_err());
        assert!(func
            .evaluate(&[DataValue::Null, DataValue::Null, DataValue::Null])
            .is_err());
    }

    #[test]
    fn test_iif_truthiness() {
        let func = IifFunction;
        let pick = |cond: DataValue| func.evaluate(&[cond, s("yes"), s("no")]).unwrap();

        assert_eq!(pick(DataValue::Boolean(true)), s("yes"));
        assert_eq!(pick(DataValue::Boolean(false)), s("no"));
        assert_eq!(pick(DataValue::Integer(2)), s("yes"));
        assert_eq!(pick(DataValue::Integer(0)), s("no"));
        assert_eq!(pick(DataValue::Float(0.5)), s("yes"));
        assert_eq!(pick(DataValue::Float(f64::NAN)), s("no"));
        assert_eq!(pick(s("x")), s("yes"));
        assert_eq!(pick(s("")), s("no"));
        assert_eq!(pick(DataValue::Null), s("no"));
    }

    #[test]
    fn test_greatest_label_picks_label_of_max() {
        let func = GreatestLabelFunction;
        let args = vec![
            s("jan"),
            DataValue::Integer(100),
            s("feb"),
            DataValue::Integer(150),
            s("mar"),
            DataValue::Integer(120),
        ];
        assert_eq!(func.evaluate(&args).unwrap(), s("feb"));
    }

    #[test]
    fn test_greatest_label_tie_keeps_first() {
        let func = GreatestLabelFunction;
        let args = vec![s("a"), DataValue::Integer(5), s("b"), DataValue::Float(5.0)];
        assert_eq!(func.evaluate(&args).unwrap(), s("a"));
    }

    #[test]
    fn test_greatest_label_type_mismatch_is_error() {
        let func = GreatestLabelFunction;
        let args = vec![s("a"), DataValue::Integer(1), s("b"), s("x")];
        assert!(func.evaluate(&args).is_err());
    }

    #[test]
    fn test_least_label_picks_label_of_min() {
        let func = LeastLabelFunction;
        let args = vec![
            s("cost_a"),
            DataValue::Integer(50),
            s("cost_b"),
            DataValue::Integer(30),
        ];
        assert_eq!(func.evaluate(&args).unwrap(), s("cost_b"));
    }

    #[test]
    fn test_least_label_skips_null_values() {
        let func = LeastLabelFunction;
        let args = vec![
            s("a"),
            DataValue::Null,
            s("b"),
            DataValue::Integer(30),
            s("c"),
            DataValue::Integer(50),
        ];
        assert_eq!(func.evaluate(&args).unwrap(), s("b"));
    }

    #[test]
    fn test_label_functions_all_null_values_return_null() {
        let args = vec![s("a"), DataValue::Null, s("b"), DataValue::Null];
        assert_eq!(GreatestLabelFunction.evaluate(&args).unwrap(), DataValue::Null);
        assert_eq!(LeastLabelFunction.evaluate(&args).unwrap(), DataValue::Null);
    }

    #[test]
    fn test_label_functions_reject_bad_arity() {
        let odd = vec![s("only_label")];
        assert!(GreatestLabelFunction.evaluate(&odd).is_err());
        assert!(LeastLabelFunction.evaluate(&odd).is_err());
        assert!(GreatestLabelFunction.evaluate(&[]).is_err());
        assert!(LeastLabelFunction.evaluate(&[]).is_err());
    }
}