// sql_cli/sql/functions/group_num.rs

1use crate::data::datatable::DataValue;
2use crate::sql::functions::{ArgCount, FunctionCategory, FunctionSignature, SqlFunction};
3use anyhow::Result;
4use lazy_static::lazy_static;
5use std::collections::HashMap;
6use std::sync::Mutex;
7
8// Global memoization state for GROUP_NUM function
9// This ensures consistency across the entire query
10lazy_static! {
11    static ref GROUP_NUM_MEMO: Mutex<HashMap<String, HashMap<String, i64>>> =
12        Mutex::new(HashMap::new());
13}
14
15/// GROUP_NUM function - assigns a unique number (starting from 0) to each distinct value
16/// Maintains consistency across the entire query execution
17pub struct GroupNumFunction;
18
19impl GroupNumFunction {
20    pub fn new() -> Self {
21        Self
22    }
23
24    /// Clear all memoization (should be called before each new query)
25    pub fn clear_memoization() {
26        let mut memo = GROUP_NUM_MEMO.lock().unwrap();
27        memo.clear();
28    }
29}
30
31impl SqlFunction for GroupNumFunction {
32    fn signature(&self) -> FunctionSignature {
33        FunctionSignature {
34            name: "GROUP_NUM",
35            category: FunctionCategory::Aggregate,
36            arg_count: ArgCount::Fixed(1),
37            description: "Assigns unique sequential numbers (starting from 0) to distinct values",
38            returns: "Integer - unique number for each distinct value",
39            examples: vec![
40                "SELECT order_id, GROUP_NUM(order_id) as grp_num FROM orders",
41                "SELECT customer, GROUP_NUM(customer) as cust_num FROM sales",
42            ],
43        }
44    }
45
46    fn evaluate(&self, args: &[DataValue]) -> Result<DataValue> {
47        if args.len() != 1 {
48            anyhow::bail!("GROUP_NUM requires exactly 1 argument");
49        }
50
51        // Convert the value to a string for consistent hashing
52        let value_str = match &args[0] {
53            DataValue::Null => return Ok(DataValue::Null),
54            DataValue::String(s) => s.clone(),
55            DataValue::InternedString(s) => s.to_string(),
56            DataValue::Integer(i) => i.to_string(),
57            DataValue::Float(f) => f.to_string(),
58            DataValue::Boolean(b) => b.to_string(),
59            DataValue::DateTime(dt) => dt.to_string(),
60            DataValue::Vector(v) => {
61                let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
62                format!("[{}]", components.join(","))
63            }
64        };
65
66        // For now, use a default column identifier
67        // In a full implementation, we'd track which column this is for
68        let column_id = "_default_";
69
70        let mut memo = GROUP_NUM_MEMO.lock().unwrap();
71        let column_map = memo
72            .entry(column_id.to_string())
73            .or_insert_with(HashMap::new);
74
75        // Check if we've seen this value before
76        if let Some(&num) = column_map.get(&value_str) {
77            Ok(DataValue::Integer(num))
78        } else {
79            // Assign a new number
80            let new_num = column_map.len() as i64;
81            column_map.insert(value_str, new_num);
82            Ok(DataValue::Integer(new_num))
83        }
84    }
85}
86
87/// Extended version that can track column context
88pub struct GroupNumWithContext;
89
90impl GroupNumWithContext {
91    pub fn new() -> Self {
92        Self
93    }
94
95    /// Evaluate with column context
96    pub fn evaluate_with_context(&self, value: &DataValue, column_name: &str) -> Result<DataValue> {
97        if matches!(value, DataValue::Null) {
98            return Ok(DataValue::Null);
99        }
100
101        let value_str = match value {
102            DataValue::String(s) => s.clone(),
103            DataValue::InternedString(s) => s.to_string(),
104            DataValue::Integer(i) => i.to_string(),
105            DataValue::Float(f) => f.to_string(),
106            DataValue::Boolean(b) => b.to_string(),
107            DataValue::DateTime(dt) => dt.to_string(),
108            DataValue::Vector(v) => {
109                let components: Vec<String> = v.iter().map(|f| f.to_string()).collect();
110                format!("[{}]", components.join(","))
111            }
112            DataValue::Null => unreachable!(),
113        };
114
115        let mut memo = GROUP_NUM_MEMO.lock().unwrap();
116        let column_map = memo
117            .entry(column_name.to_string())
118            .or_insert_with(HashMap::new);
119
120        if let Some(&num) = column_map.get(&value_str) {
121            Ok(DataValue::Integer(num))
122        } else {
123            let new_num = column_map.len() as i64;
124            column_map.insert(value_str, new_num);
125            Ok(DataValue::Integer(new_num))
126        }
127    }
128
129    /// Clear memoization for a specific column
130    pub fn clear_column(&self, column_name: &str) {
131        let mut memo = GROUP_NUM_MEMO.lock().unwrap();
132        memo.remove(column_name);
133    }
134
135    /// Clear all memoization
136    pub fn clear_all(&self) {
137        let mut memo = GROUP_NUM_MEMO.lock().unwrap();
138        memo.clear();
139    }
140}