sql_cli/sql/functions/group_num.rs

1use crate::data::datatable::DataValue;
2use crate::sql::functions::{ArgCount, FunctionCategory, FunctionSignature, SqlFunction};
3use anyhow::Result;
4use lazy_static::lazy_static;
5use std::collections::HashMap;
6use std::sync::Mutex;
7
8// Global memoization state for GROUP_NUM function
9// This ensures consistency across the entire query
10lazy_static! {
11    static ref GROUP_NUM_MEMO: Mutex<HashMap<String, HashMap<String, i64>>> =
12        Mutex::new(HashMap::new());
13}
14
/// GROUP_NUM function - assigns a unique number (starting from 0) to each distinct value.
///
/// The struct itself carries no state. The value → number assignments live in a
/// shared memo table, so numbering stays consistent across calls until that
/// table is cleared.
// `Default` is derived so the public zero-argument `new()` has the standard
// counterpart (clippy `new_without_default`).
#[derive(Debug, Default, Clone, Copy)]
pub struct GroupNumFunction;
18
19impl GroupNumFunction {
20    pub fn new() -> Self {
21        Self
22    }
23
24    /// Clear all memoization (should be called before each new query)
25    pub fn clear_memoization() {
26        let mut memo = GROUP_NUM_MEMO.lock().unwrap();
27        memo.clear();
28    }
29}
30
31impl SqlFunction for GroupNumFunction {
32    fn signature(&self) -> FunctionSignature {
33        FunctionSignature {
34            name: "GROUP_NUM",
35            category: FunctionCategory::Aggregate,
36            arg_count: ArgCount::Fixed(1),
37            description: "Assigns unique sequential numbers (starting from 0) to distinct values",
38            returns: "Integer - unique number for each distinct value",
39            examples: vec![
40                "SELECT order_id, GROUP_NUM(order_id) as grp_num FROM orders",
41                "SELECT customer, GROUP_NUM(customer) as cust_num FROM sales",
42            ],
43        }
44    }
45
46    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
47        if args.len() != 1 {
48            anyhow::bail!("GROUP_NUM requires exactly 1 argument");
49        }
50
51        // Convert the value to a string for consistent hashing
52        let value_str = match &args[0] {
53            DataValue::Null => return Ok(DataValue::Null),
54            DataValue::String(s) => s.clone(),
55            DataValue::InternedString(s) => s.to_string(),
56            DataValue::Integer(i) => i.to_string(),
57            DataValue::Float(f) => f.to_string(),
58            DataValue::Boolean(b) => b.to_string(),
59            DataValue::DateTime(dt) => dt.to_string(),
60        };
61
62        // For now, use a default column identifier
63        // In a full implementation, we'd track which column this is for
64        let column_id = "_default_";
65
66        let mut memo = GROUP_NUM_MEMO.lock().unwrap();
67        let column_map = memo
68            .entry(column_id.to_string())
69            .or_insert_with(HashMap::new);
70
71        // Check if we've seen this value before
72        if let Some(&num) = column_map.get(&value_str) {
73            Ok(DataValue::Integer(num))
74        } else {
75            // Assign a new number
76            let new_num = column_map.len() as i64;
77            column_map.insert(value_str, new_num);
78            Ok(DataValue::Integer(new_num))
79        }
80    }
81}
82
/// Extended version of GROUP_NUM that numbers values separately per column name.
///
/// Stateless handle: assignments are stored in the shared memo table.
// `Default` is derived so the public zero-argument `new()` has the standard
// counterpart (clippy `new_without_default`).
#[derive(Debug, Default, Clone, Copy)]
pub struct GroupNumWithContext;
85
86impl GroupNumWithContext {
87    pub fn new() -> Self {
88        Self
89    }
90
91    /// Evaluate with column context
92    pub fn evaluate_with_context(&self, value: &DataValue, column_name: &str) -> Result<DataValue> {
93        if matches!(value, DataValue::Null) {
94            return Ok(DataValue::Null);
95        }
96
97        let value_str = match value {
98            DataValue::String(s) => s.clone(),
99            DataValue::InternedString(s) => s.to_string(),
100            DataValue::Integer(i) => i.to_string(),
101            DataValue::Float(f) => f.to_string(),
102            DataValue::Boolean(b) => b.to_string(),
103            DataValue::DateTime(dt) => dt.to_string(),
104            DataValue::Null => unreachable!(),
105        };
106
107        let mut memo = GROUP_NUM_MEMO.lock().unwrap();
108        let column_map = memo
109            .entry(column_name.to_string())
110            .or_insert_with(HashMap::new);
111
112        if let Some(&num) = column_map.get(&value_str) {
113            Ok(DataValue::Integer(num))
114        } else {
115            let new_num = column_map.len() as i64;
116            column_map.insert(value_str, new_num);
117            Ok(DataValue::Integer(new_num))
118        }
119    }
120
121    /// Clear memoization for a specific column
122    pub fn clear_column(&self, column_name: &str) {
123        let mut memo = GROUP_NUM_MEMO.lock().unwrap();
124        memo.remove(column_name);
125    }
126
127    /// Clear all memoization
128    pub fn clear_all(&self) {
129        let mut memo = GROUP_NUM_MEMO.lock().unwrap();
130        memo.clear();
131    }
132}