sql_cli/sql/
template_expander.rs

// Template expansion for WEB CTEs: inject temp table data into URLs and request bodies.
// Supported syntax: ${#table_name}, ${#table.column}, ${#table[0]}, ${#table[0].column}
3
4use crate::data::datatable::DataTable;
5use crate::data::temp_table_registry::TempTableRegistry;
6use crate::sql::parser::ast::TemplateVar;
7use anyhow::{bail, Result};
8use regex::Regex;
9use serde_json::{json, Value};
10use std::sync::Arc;
11
12/// Template expander for injecting temp table data into WEB CTE requests
13pub struct TemplateExpander<'a> {
14    temp_tables: &'a TempTableRegistry,
15}
16
17impl<'a> TemplateExpander<'a> {
18    /// Create a new template expander with access to temp tables
19    pub fn new(temp_tables: &'a TempTableRegistry) -> Self {
20        Self { temp_tables }
21    }
22
23    /// Parse template variables from a string
24    /// Finds patterns like: ${#table}, ${#table.column}, ${#table[0].column}
25    pub fn parse_templates(&self, text: &str) -> Result<Vec<TemplateVar>> {
26        let mut vars = Vec::new();
27
28        // Regex to match ${#table_name}, ${#table.column}, ${#table[0]}, ${#table[0].column}
29        // Pattern breakdown: ${(#\w+)(?:\[(\d+)\])?(?:\.(\w+))?}
30        // Capture groups: 1=#table_name, 2=index, 3=column
31        let re = Regex::new(r"\$\{(#\w+)(?:\[(\d+)\])?(?:\.(\w+))?\}").unwrap();
32
33        for cap in re.captures_iter(text) {
34            let table_name = cap.get(1).unwrap().as_str().to_string();
35            let index = cap.get(2).and_then(|m| m.as_str().parse::<usize>().ok());
36            let column = cap.get(3).map(|m| m.as_str().to_string());
37            let placeholder = cap.get(0).unwrap().as_str().to_string();
38
39            vars.push(TemplateVar {
40                placeholder,
41                table_name,
42                column,
43                index,
44            });
45        }
46
47        Ok(vars)
48    }
49
50    /// Expand templates in a string by replacing placeholders with temp table data
51    pub fn expand(&self, text: &str, template_vars: &[TemplateVar]) -> Result<String> {
52        let mut result = text.to_string();
53
54        for var in template_vars {
55            let replacement = self.resolve_template_var(var)?;
56            result = result.replace(&var.placeholder, &replacement);
57        }
58
59        Ok(result)
60    }
61
62    /// Resolve a single template variable to its JSON representation
63    fn resolve_template_var(&self, var: &TemplateVar) -> Result<String> {
64        // Get the temp table
65        let table = self
66            .temp_tables
67            .get(&var.table_name)
68            .ok_or_else(|| anyhow::anyhow!("Temp table '{}' not found", var.table_name))?;
69
70        // Case 1: ${#table} - entire table as JSON array
71        if var.column.is_none() && var.index.is_none() {
72            return Ok(self.table_to_json(&table)?);
73        }
74
75        // Case 2: ${#table[0]} - single row as JSON object
76        if var.column.is_none() && var.index.is_some() {
77            let index = var.index.unwrap();
78            return Ok(self.row_to_json(&table, index)?);
79        }
80
81        // Case 3: ${#table.column} - array of column values
82        if var.column.is_some() && var.index.is_none() {
83            let column = var.column.as_ref().unwrap();
84            return Ok(self.column_to_json(&table, column)?);
85        }
86
87        // Case 4: ${#table[0].column} - single cell value
88        if var.column.is_some() && var.index.is_some() {
89            let column = var.column.as_ref().unwrap();
90            let index = var.index.unwrap();
91            return Ok(self.cell_to_json(&table, index, column)?);
92        }
93
94        bail!("Invalid template variable: {}", var.placeholder)
95    }
96
97    /// Convert entire table to JSON array
98    fn table_to_json(&self, table: &Arc<DataTable>) -> Result<String> {
99        let mut rows = Vec::new();
100
101        for row_idx in 0..table.row_count() {
102            let mut row_obj = serde_json::Map::new();
103            for (col_idx, col_name) in table.column_names().iter().enumerate() {
104                if let Some(cell_value) = table.get_value(row_idx, col_idx) {
105                    let value = self.data_value_to_json(cell_value);
106                    row_obj.insert(col_name.clone(), value);
107                }
108            }
109            rows.push(Value::Object(row_obj));
110        }
111
112        Ok(serde_json::to_string(&rows)?)
113    }
114
115    /// Convert single row to JSON object
116    fn row_to_json(&self, table: &Arc<DataTable>, row_idx: usize) -> Result<String> {
117        if row_idx >= table.row_count() {
118            bail!(
119                "Row index {} out of bounds (table has {} rows)",
120                row_idx,
121                table.row_count()
122            );
123        }
124
125        let mut row_obj = serde_json::Map::new();
126        for (col_idx, col_name) in table.column_names().iter().enumerate() {
127            if let Some(cell_value) = table.get_value(row_idx, col_idx) {
128                let value = self.data_value_to_json(cell_value);
129                row_obj.insert(col_name.clone(), value);
130            }
131        }
132
133        Ok(serde_json::to_string(&Value::Object(row_obj))?)
134    }
135
136    /// Convert column to JSON array
137    fn column_to_json(&self, table: &Arc<DataTable>, column_name: &str) -> Result<String> {
138        let col_idx = table
139            .column_names()
140            .iter()
141            .position(|name| name.eq_ignore_ascii_case(column_name))
142            .ok_or_else(|| anyhow::anyhow!("Column '{}' not found in table", column_name))?;
143
144        let mut values = Vec::new();
145        for row_idx in 0..table.row_count() {
146            if let Some(cell_value) = table.get_value(row_idx, col_idx) {
147                let value = self.data_value_to_json(cell_value);
148                values.push(value);
149            }
150        }
151
152        Ok(serde_json::to_string(&values)?)
153    }
154
155    /// Convert single cell to JSON value
156    fn cell_to_json(
157        &self,
158        table: &Arc<DataTable>,
159        row_idx: usize,
160        column_name: &str,
161    ) -> Result<String> {
162        if row_idx >= table.row_count() {
163            bail!(
164                "Row index {} out of bounds (table has {} rows)",
165                row_idx,
166                table.row_count()
167            );
168        }
169
170        let col_idx = table
171            .column_names()
172            .iter()
173            .position(|name| name.eq_ignore_ascii_case(column_name))
174            .ok_or_else(|| anyhow::anyhow!("Column '{}' not found in table", column_name))?;
175
176        let cell_value = table.get_value(row_idx, col_idx).ok_or_else(|| {
177            anyhow::anyhow!("No value found at row {}, column {}", row_idx, col_idx)
178        })?;
179        let value = self.data_value_to_json(cell_value);
180        Ok(serde_json::to_string(&value)?)
181    }
182
183    /// Convert DataValue to serde_json::Value
184    fn data_value_to_json(&self, data_value: &crate::data::datatable::DataValue) -> Value {
185        use crate::data::datatable::DataValue;
186
187        match data_value {
188            DataValue::Integer(i) => json!(i),
189            DataValue::Float(f) => json!(f),
190            DataValue::String(s) => json!(s),
191            DataValue::InternedString(s) => json!(s.as_str()),
192            DataValue::Boolean(b) => json!(b),
193            DataValue::DateTime(dt) => json!(dt),
194            DataValue::Null => Value::Null,
195        }
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow, DataTable, DataType, DataValue};

    /// Two-row instrument table: (AAPL, 150.5, 100) and (GOOGL, 2800.25, 50).
    fn create_test_table() -> Arc<DataTable> {
        let mut table = DataTable::new("test_instruments");
        table.add_column(DataColumn::new("symbol").with_type(DataType::String));
        table.add_column(DataColumn::new("price").with_type(DataType::Float));
        table.add_column(DataColumn::new("quantity").with_type(DataType::Integer));

        table
            .add_row(DataRow::new(vec![
                DataValue::String("AAPL".to_string()),
                DataValue::Float(150.5),
                DataValue::Integer(100),
            ]))
            .unwrap();
        table
            .add_row(DataRow::new(vec![
                DataValue::String("GOOGL".to_string()),
                DataValue::Float(2800.25),
                DataValue::Integer(50),
            ]))
            .unwrap();

        Arc::new(table)
    }

    /// Registry holding the test table under `#instruments`.
    fn registry_with_instruments() -> TempTableRegistry {
        let mut registry = TempTableRegistry::new();
        registry
            .insert("#instruments".to_string(), create_test_table())
            .unwrap();
        registry
    }

    /// Parse `text` and expand it against `registry` in one step.
    fn expand_text(registry: &TempTableRegistry, text: &str) -> Result<String> {
        let expander = TemplateExpander::new(registry);
        let vars = expander.parse_templates(text)?;
        expander.expand(text, &vars)
    }

    #[test]
    fn test_parse_templates_simple() {
        let registry = TempTableRegistry::new();
        let expander = TemplateExpander::new(&registry);

        let text = "SELECT * FROM external WHERE symbols IN ${#instruments}";
        let vars = expander.parse_templates(text).unwrap();

        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].placeholder, "${#instruments}");
        assert_eq!(vars[0].table_name, "#instruments");
        assert!(vars[0].column.is_none());
        assert!(vars[0].index.is_none());
    }

    #[test]
    fn test_parse_templates_with_column() {
        let registry = TempTableRegistry::new();
        let expander = TemplateExpander::new(&registry);

        let text = "symbols: ${#instruments.symbol}";
        let vars = expander.parse_templates(text).unwrap();

        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].placeholder, "${#instruments.symbol}");
        assert_eq!(vars[0].table_name, "#instruments");
        assert_eq!(vars[0].column, Some("symbol".to_string()));
        assert!(vars[0].index.is_none());
    }

    #[test]
    fn test_parse_templates_with_index() {
        let registry = TempTableRegistry::new();
        let expander = TemplateExpander::new(&registry);

        let text = "first row: ${#instruments[0]}";
        let vars = expander.parse_templates(text).unwrap();

        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].placeholder, "${#instruments[0]}");
        assert_eq!(vars[0].table_name, "#instruments");
        assert!(vars[0].column.is_none());
        assert_eq!(vars[0].index, Some(0));
    }

    #[test]
    fn test_parse_templates_with_index_and_column() {
        let registry = TempTableRegistry::new();
        let expander = TemplateExpander::new(&registry);

        let text = "first symbol: ${#instruments[0].symbol}";
        let vars = expander.parse_templates(text).unwrap();

        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].placeholder, "${#instruments[0].symbol}");
        assert_eq!(vars[0].table_name, "#instruments");
        assert_eq!(vars[0].column, Some("symbol".to_string()));
        assert_eq!(vars[0].index, Some(0));
    }

    #[test]
    fn test_parse_templates_none() {
        let registry = TempTableRegistry::new();
        let expander = TemplateExpander::new(&registry);

        let vars = expander.parse_templates("no templates here").unwrap();
        assert!(vars.is_empty());
    }

    #[test]
    fn test_parse_templates_multiple_in_order() {
        let registry = TempTableRegistry::new();
        let expander = TemplateExpander::new(&registry);

        let vars = expander.parse_templates("A ${#a} B ${#b[2].x}").unwrap();

        assert_eq!(vars.len(), 2);
        assert_eq!(vars[0].table_name, "#a");
        assert!(vars[0].index.is_none());
        assert_eq!(vars[1].table_name, "#b");
        assert_eq!(vars[1].index, Some(2));
        assert_eq!(vars[1].column, Some("x".to_string()));
    }

    #[test]
    fn test_expand_entire_table() {
        let registry = registry_with_instruments();
        let result = expand_text(&registry, "Data: ${#instruments}").unwrap();

        assert!(result.contains("AAPL"));
        assert!(result.contains("GOOGL"));
        assert!(result.contains("150.5"));
        assert!(result.contains("2800.25"));
    }

    #[test]
    fn test_expand_single_row() {
        let registry = registry_with_instruments();
        let result = expand_text(&registry, "Row: ${#instruments[1]}").unwrap();

        assert!(result.starts_with("Row: {"));
        assert!(result.ends_with('}'));
        assert!(result.contains("\"GOOGL\""));
        assert!(result.contains("2800.25"));
        assert!(result.contains("50"));
        assert!(!result.contains("AAPL"));
    }

    #[test]
    fn test_expand_single_column() {
        let registry = registry_with_instruments();
        let result = expand_text(&registry, "Symbols: ${#instruments.symbol}").unwrap();

        assert_eq!(result, r#"Symbols: ["AAPL","GOOGL"]"#);
        assert!(!result.contains("150.5")); // Price should not be included
    }

    #[test]
    fn test_expand_column_case_insensitive() {
        let registry = registry_with_instruments();
        let result = expand_text(&registry, "${#instruments.SYMBOL}").unwrap();

        assert_eq!(result, r#"["AAPL","GOOGL"]"#);
    }

    #[test]
    fn test_expand_single_cell() {
        let registry = registry_with_instruments();
        let result = expand_text(&registry, "First symbol: ${#instruments[0].symbol}").unwrap();

        assert_eq!(result, r#"First symbol: "AAPL""#);
        assert!(!result.contains("GOOGL"));
    }

    #[test]
    fn test_expand_integer_cell() {
        let registry = registry_with_instruments();
        let result = expand_text(&registry, "qty=${#instruments[1].quantity}").unwrap();

        assert_eq!(result, "qty=50");
    }

    #[test]
    fn test_expand_multiple_templates() {
        let registry = registry_with_instruments();
        let text = "First: ${#instruments[0].symbol}, All: ${#instruments.symbol}";
        let result = expand_text(&registry, text).unwrap();

        assert_eq!(result, r#"First: "AAPL", All: ["AAPL","GOOGL"]"#);
    }

    #[test]
    fn test_expand_without_templates_is_identity() {
        let registry = registry_with_instruments();
        let result = expand_text(&registry, "plain text, no placeholders").unwrap();

        assert_eq!(result, "plain text, no placeholders");
    }

    #[test]
    fn test_expand_missing_table_errors() {
        let registry = TempTableRegistry::new();
        let err = expand_text(&registry, "${#missing}").unwrap_err();

        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn test_expand_row_out_of_bounds_errors() {
        let registry = registry_with_instruments();
        let err = expand_text(&registry, "${#instruments[5]}").unwrap_err();

        assert!(err.to_string().contains("out of bounds"));
    }

    #[test]
    fn test_expand_missing_column_errors() {
        let registry = registry_with_instruments();
        let err = expand_text(&registry, "${#instruments.nope}").unwrap_err();

        assert!(err.to_string().contains("Column 'nope' not found"));
    }
}