// sql_cli/sql/template_expander.rs

// Template expansion for WEB CTEs: injects temp table data into URLs and request bodies.
// Supported syntax: ${#table_name}, ${#table_name.column}, ${#table_name[0]}, ${#table_name[0].column}
3
4use crate::data::datatable::DataTable;
5use crate::data::temp_table_registry::TempTableRegistry;
6use crate::sql::parser::ast::TemplateVar;
7use anyhow::{bail, Result};
8use regex::Regex;
9use serde_json::{json, Value};
10use std::sync::Arc;
11
12/// Template expander for injecting temp table data into WEB CTE requests
13pub struct TemplateExpander<'a> {
14    temp_tables: &'a TempTableRegistry,
15}
16
17impl<'a> TemplateExpander<'a> {
18    /// Create a new template expander with access to temp tables
19    pub fn new(temp_tables: &'a TempTableRegistry) -> Self {
20        Self { temp_tables }
21    }
22
23    /// Parse template variables from a string
24    /// Finds patterns like: ${#table}, ${#table.column}, ${#table[0].column}
25    pub fn parse_templates(&self, text: &str) -> Result<Vec<TemplateVar>> {
26        let mut vars = Vec::new();
27
28        // Regex to match ${#table_name}, ${#table.column}, ${#table[0]}, ${#table[0].column}
29        // Pattern breakdown: ${(#\w+)(?:\[(\d+)\])?(?:\.(\w+))?}
30        // Capture groups: 1=#table_name, 2=index, 3=column
31        let re = Regex::new(r"\$\{(#\w+)(?:\[(\d+)\])?(?:\.(\w+))?\}").unwrap();
32
33        for cap in re.captures_iter(text) {
34            let table_name = cap.get(1).unwrap().as_str().to_string();
35            let index = cap.get(2).and_then(|m| m.as_str().parse::<usize>().ok());
36            let column = cap.get(3).map(|m| m.as_str().to_string());
37            let placeholder = cap.get(0).unwrap().as_str().to_string();
38
39            vars.push(TemplateVar {
40                placeholder,
41                table_name,
42                column,
43                index,
44            });
45        }
46
47        Ok(vars)
48    }
49
50    /// Expand templates in a string by replacing placeholders with temp table data
51    pub fn expand(&self, text: &str, template_vars: &[TemplateVar]) -> Result<String> {
52        let mut result = text.to_string();
53
54        for var in template_vars {
55            let replacement = self.resolve_template_var(var)?;
56            result = result.replace(&var.placeholder, &replacement);
57        }
58
59        Ok(result)
60    }
61
62    /// Resolve a single template variable to its JSON representation
63    fn resolve_template_var(&self, var: &TemplateVar) -> Result<String> {
64        // Get the temp table
65        let table = self
66            .temp_tables
67            .get(&var.table_name)
68            .ok_or_else(|| anyhow::anyhow!("Temp table '{}' not found", var.table_name))?;
69
70        // Case 1: ${#table} - entire table as JSON array
71        if var.column.is_none() && var.index.is_none() {
72            return Ok(self.table_to_json(&table)?);
73        }
74
75        // Case 2: ${#table[0]} - single row as JSON object
76        if var.column.is_none() && var.index.is_some() {
77            let index = var.index.unwrap();
78            return Ok(self.row_to_json(&table, index)?);
79        }
80
81        // Case 3: ${#table.column} - array of column values
82        if var.column.is_some() && var.index.is_none() {
83            let column = var.column.as_ref().unwrap();
84            return Ok(self.column_to_json(&table, column)?);
85        }
86
87        // Case 4: ${#table[0].column} - single cell value
88        if var.column.is_some() && var.index.is_some() {
89            let column = var.column.as_ref().unwrap();
90            let index = var.index.unwrap();
91            return Ok(self.cell_to_json(&table, index, column)?);
92        }
93
94        bail!("Invalid template variable: {}", var.placeholder)
95    }
96
97    /// Convert entire table to JSON array
98    fn table_to_json(&self, table: &Arc<DataTable>) -> Result<String> {
99        let mut rows = Vec::new();
100
101        for row_idx in 0..table.row_count() {
102            let mut row_obj = serde_json::Map::new();
103            for (col_idx, col_name) in table.column_names().iter().enumerate() {
104                if let Some(cell_value) = table.get_value(row_idx, col_idx) {
105                    let value = self.data_value_to_json(cell_value);
106                    row_obj.insert(col_name.clone(), value);
107                }
108            }
109            rows.push(Value::Object(row_obj));
110        }
111
112        Ok(serde_json::to_string(&rows)?)
113    }
114
115    /// Convert single row to JSON object
116    fn row_to_json(&self, table: &Arc<DataTable>, row_idx: usize) -> Result<String> {
117        if row_idx >= table.row_count() {
118            bail!(
119                "Row index {} out of bounds (table has {} rows)",
120                row_idx,
121                table.row_count()
122            );
123        }
124
125        let mut row_obj = serde_json::Map::new();
126        for (col_idx, col_name) in table.column_names().iter().enumerate() {
127            if let Some(cell_value) = table.get_value(row_idx, col_idx) {
128                let value = self.data_value_to_json(cell_value);
129                row_obj.insert(col_name.clone(), value);
130            }
131        }
132
133        Ok(serde_json::to_string(&Value::Object(row_obj))?)
134    }
135
136    /// Convert column to JSON array
137    fn column_to_json(&self, table: &Arc<DataTable>, column_name: &str) -> Result<String> {
138        let col_idx = table
139            .column_names()
140            .iter()
141            .position(|name| name.eq_ignore_ascii_case(column_name))
142            .ok_or_else(|| anyhow::anyhow!("Column '{}' not found in table", column_name))?;
143
144        let mut values = Vec::new();
145        for row_idx in 0..table.row_count() {
146            if let Some(cell_value) = table.get_value(row_idx, col_idx) {
147                let value = self.data_value_to_json(cell_value);
148                values.push(value);
149            }
150        }
151
152        Ok(serde_json::to_string(&values)?)
153    }
154
155    /// Convert single cell to JSON value
156    fn cell_to_json(
157        &self,
158        table: &Arc<DataTable>,
159        row_idx: usize,
160        column_name: &str,
161    ) -> Result<String> {
162        if row_idx >= table.row_count() {
163            bail!(
164                "Row index {} out of bounds (table has {} rows)",
165                row_idx,
166                table.row_count()
167            );
168        }
169
170        let col_idx = table
171            .column_names()
172            .iter()
173            .position(|name| name.eq_ignore_ascii_case(column_name))
174            .ok_or_else(|| anyhow::anyhow!("Column '{}' not found in table", column_name))?;
175
176        let cell_value = table.get_value(row_idx, col_idx).ok_or_else(|| {
177            anyhow::anyhow!("No value found at row {}, column {}", row_idx, col_idx)
178        })?;
179        let value = self.data_value_to_json(cell_value);
180        Ok(serde_json::to_string(&value)?)
181    }
182
183    /// Convert DataValue to serde_json::Value
184    fn data_value_to_json(&self, data_value: &crate::data::datatable::DataValue) -> Value {
185        use crate::data::datatable::DataValue;
186
187        match data_value {
188            DataValue::Integer(i) => json!(i),
189            DataValue::Float(f) => json!(f),
190            DataValue::String(s) => json!(s),
191            DataValue::InternedString(s) => json!(s.as_str()),
192            DataValue::Boolean(b) => json!(b),
193            DataValue::DateTime(dt) => json!(dt),
194            DataValue::Vector(v) => json!(v),
195            DataValue::Null => Value::Null,
196        }
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow, DataTable, DataType, DataValue};

    /// Two-row instruments table with columns (symbol, price, quantity).
    fn create_test_table() -> Arc<DataTable> {
        let mut table = DataTable::new("test_instruments");
        table.add_column(DataColumn::new("symbol").with_type(DataType::String));
        table.add_column(DataColumn::new("price").with_type(DataType::Float));
        table.add_column(DataColumn::new("quantity").with_type(DataType::Integer));

        let rows = [
            vec![
                DataValue::String("AAPL".to_string()),
                DataValue::Float(150.5),
                DataValue::Integer(100),
            ],
            vec![
                DataValue::String("GOOGL".to_string()),
                DataValue::Float(2800.25),
                DataValue::Integer(50),
            ],
        ];
        for values in rows {
            table.add_row(DataRow::new(values)).unwrap();
        }

        Arc::new(table)
    }

    /// Parse `text` using an expander over an empty registry.
    /// Parsing does not consult the registry.
    fn parse(text: &str) -> Vec<TemplateVar> {
        let registry = TempTableRegistry::new();
        let expander = TemplateExpander::new(&registry);
        expander.parse_templates(text).unwrap()
    }

    /// Parse and expand `text` against a registry that holds `#instruments`.
    fn expand_with_instruments(text: &str) -> String {
        let mut registry = TempTableRegistry::new();
        registry
            .insert("#instruments".to_string(), create_test_table())
            .unwrap();

        let expander = TemplateExpander::new(&registry);
        let vars = expander.parse_templates(text).unwrap();
        expander.expand(text, &vars).unwrap()
    }

    #[test]
    fn test_parse_templates_simple() {
        let vars = parse("SELECT * FROM external WHERE symbols IN ${#instruments}");

        assert_eq!(vars.len(), 1);
        let var = &vars[0];
        assert_eq!(var.placeholder, "${#instruments}");
        assert_eq!(var.table_name, "#instruments");
        assert!(var.column.is_none());
        assert!(var.index.is_none());
    }

    #[test]
    fn test_parse_templates_with_column() {
        let vars = parse("symbols: ${#instruments.symbol}");

        assert_eq!(vars.len(), 1);
        let var = &vars[0];
        assert_eq!(var.placeholder, "${#instruments.symbol}");
        assert_eq!(var.table_name, "#instruments");
        assert_eq!(var.column, Some("symbol".to_string()));
        assert!(var.index.is_none());
    }

    #[test]
    fn test_parse_templates_with_index() {
        let vars = parse("first row: ${#instruments[0]}");

        assert_eq!(vars.len(), 1);
        let var = &vars[0];
        assert_eq!(var.placeholder, "${#instruments[0]}");
        assert_eq!(var.table_name, "#instruments");
        assert!(var.column.is_none());
        assert_eq!(var.index, Some(0));
    }

    #[test]
    fn test_parse_templates_with_index_and_column() {
        let vars = parse("first symbol: ${#instruments[0].symbol}");

        assert_eq!(vars.len(), 1);
        let var = &vars[0];
        assert_eq!(var.placeholder, "${#instruments[0].symbol}");
        assert_eq!(var.table_name, "#instruments");
        assert_eq!(var.column, Some("symbol".to_string()));
        assert_eq!(var.index, Some(0));
    }

    #[test]
    fn test_expand_entire_table() {
        let result = expand_with_instruments("Data: ${#instruments}");

        for expected in ["AAPL", "GOOGL", "150.5", "2800.25"] {
            assert!(result.contains(expected), "missing {} in {}", expected, result);
        }
    }

    #[test]
    fn test_expand_single_column() {
        let result = expand_with_instruments("Symbols: ${#instruments.symbol}");

        assert!(result.contains("AAPL"));
        assert!(result.contains("GOOGL"));
        // Only the symbol column is expanded, so prices must not appear.
        assert!(!result.contains("150.5"));
    }

    #[test]
    fn test_expand_single_cell() {
        let result = expand_with_instruments("First symbol: ${#instruments[0].symbol}");

        assert!(result.contains("AAPL"));
        assert!(!result.contains("GOOGL"));
    }

    #[test]
    fn test_expand_multiple_templates() {
        let result =
            expand_with_instruments("First: ${#instruments[0].symbol}, All: ${#instruments.symbol}");

        assert!(result.contains("AAPL"));
        assert!(result.contains("GOOGL"));
    }
}