1use crate::data::datatable::DataTable;
5use crate::data::temp_table_registry::TempTableRegistry;
6use crate::sql::parser::ast::TemplateVar;
7use anyhow::{bail, Result};
8use regex::Regex;
9use serde_json::{json, Value};
10use std::sync::Arc;
11
12pub struct TemplateExpander<'a> {
14 temp_tables: &'a TempTableRegistry,
15}
16
17impl<'a> TemplateExpander<'a> {
18 pub fn new(temp_tables: &'a TempTableRegistry) -> Self {
20 Self { temp_tables }
21 }
22
23 pub fn parse_templates(&self, text: &str) -> Result<Vec<TemplateVar>> {
26 let mut vars = Vec::new();
27
28 let re = Regex::new(r"\$\{(#\w+)(?:\[(\d+)\])?(?:\.(\w+))?\}").unwrap();
32
33 for cap in re.captures_iter(text) {
34 let table_name = cap.get(1).unwrap().as_str().to_string();
35 let index = cap.get(2).and_then(|m| m.as_str().parse::<usize>().ok());
36 let column = cap.get(3).map(|m| m.as_str().to_string());
37 let placeholder = cap.get(0).unwrap().as_str().to_string();
38
39 vars.push(TemplateVar {
40 placeholder,
41 table_name,
42 column,
43 index,
44 });
45 }
46
47 Ok(vars)
48 }
49
50 pub fn expand(&self, text: &str, template_vars: &[TemplateVar]) -> Result<String> {
52 let mut result = text.to_string();
53
54 for var in template_vars {
55 let replacement = self.resolve_template_var(var)?;
56 result = result.replace(&var.placeholder, &replacement);
57 }
58
59 Ok(result)
60 }
61
62 fn resolve_template_var(&self, var: &TemplateVar) -> Result<String> {
64 let table = self
66 .temp_tables
67 .get(&var.table_name)
68 .ok_or_else(|| anyhow::anyhow!("Temp table '{}' not found", var.table_name))?;
69
70 if var.column.is_none() && var.index.is_none() {
72 return Ok(self.table_to_json(&table)?);
73 }
74
75 if var.column.is_none() && var.index.is_some() {
77 let index = var.index.unwrap();
78 return Ok(self.row_to_json(&table, index)?);
79 }
80
81 if var.column.is_some() && var.index.is_none() {
83 let column = var.column.as_ref().unwrap();
84 return Ok(self.column_to_json(&table, column)?);
85 }
86
87 if var.column.is_some() && var.index.is_some() {
89 let column = var.column.as_ref().unwrap();
90 let index = var.index.unwrap();
91 return Ok(self.cell_to_json(&table, index, column)?);
92 }
93
94 bail!("Invalid template variable: {}", var.placeholder)
95 }
96
97 fn table_to_json(&self, table: &Arc<DataTable>) -> Result<String> {
99 let mut rows = Vec::new();
100
101 for row_idx in 0..table.row_count() {
102 let mut row_obj = serde_json::Map::new();
103 for (col_idx, col_name) in table.column_names().iter().enumerate() {
104 if let Some(cell_value) = table.get_value(row_idx, col_idx) {
105 let value = self.data_value_to_json(cell_value);
106 row_obj.insert(col_name.clone(), value);
107 }
108 }
109 rows.push(Value::Object(row_obj));
110 }
111
112 Ok(serde_json::to_string(&rows)?)
113 }
114
115 fn row_to_json(&self, table: &Arc<DataTable>, row_idx: usize) -> Result<String> {
117 if row_idx >= table.row_count() {
118 bail!(
119 "Row index {} out of bounds (table has {} rows)",
120 row_idx,
121 table.row_count()
122 );
123 }
124
125 let mut row_obj = serde_json::Map::new();
126 for (col_idx, col_name) in table.column_names().iter().enumerate() {
127 if let Some(cell_value) = table.get_value(row_idx, col_idx) {
128 let value = self.data_value_to_json(cell_value);
129 row_obj.insert(col_name.clone(), value);
130 }
131 }
132
133 Ok(serde_json::to_string(&Value::Object(row_obj))?)
134 }
135
136 fn column_to_json(&self, table: &Arc<DataTable>, column_name: &str) -> Result<String> {
138 let col_idx = table
139 .column_names()
140 .iter()
141 .position(|name| name.eq_ignore_ascii_case(column_name))
142 .ok_or_else(|| anyhow::anyhow!("Column '{}' not found in table", column_name))?;
143
144 let mut values = Vec::new();
145 for row_idx in 0..table.row_count() {
146 if let Some(cell_value) = table.get_value(row_idx, col_idx) {
147 let value = self.data_value_to_json(cell_value);
148 values.push(value);
149 }
150 }
151
152 Ok(serde_json::to_string(&values)?)
153 }
154
155 fn cell_to_json(
157 &self,
158 table: &Arc<DataTable>,
159 row_idx: usize,
160 column_name: &str,
161 ) -> Result<String> {
162 if row_idx >= table.row_count() {
163 bail!(
164 "Row index {} out of bounds (table has {} rows)",
165 row_idx,
166 table.row_count()
167 );
168 }
169
170 let col_idx = table
171 .column_names()
172 .iter()
173 .position(|name| name.eq_ignore_ascii_case(column_name))
174 .ok_or_else(|| anyhow::anyhow!("Column '{}' not found in table", column_name))?;
175
176 let cell_value = table.get_value(row_idx, col_idx).ok_or_else(|| {
177 anyhow::anyhow!("No value found at row {}, column {}", row_idx, col_idx)
178 })?;
179 let value = self.data_value_to_json(cell_value);
180 Ok(serde_json::to_string(&value)?)
181 }
182
183 fn data_value_to_json(&self, data_value: &crate::data::datatable::DataValue) -> Value {
185 use crate::data::datatable::DataValue;
186
187 match data_value {
188 DataValue::Integer(i) => json!(i),
189 DataValue::Float(f) => json!(f),
190 DataValue::String(s) => json!(s),
191 DataValue::InternedString(s) => json!(s.as_str()),
192 DataValue::Boolean(b) => json!(b),
193 DataValue::DateTime(dt) => json!(dt),
194 DataValue::Null => Value::Null,
195 }
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow, DataTable, DataType, DataValue};

    /// Two-row fixture: (AAPL, 150.5, 100) and (GOOGL, 2800.25, 50).
    fn create_test_table() -> Arc<DataTable> {
        let mut table = DataTable::new("test_instruments");
        table.add_column(DataColumn::new("symbol").with_type(DataType::String));
        table.add_column(DataColumn::new("price").with_type(DataType::Float));
        table.add_column(DataColumn::new("quantity").with_type(DataType::Integer));

        table
            .add_row(DataRow::new(vec![
                DataValue::String("AAPL".to_string()),
                DataValue::Float(150.5),
                DataValue::Integer(100),
            ]))
            .unwrap();
        table
            .add_row(DataRow::new(vec![
                DataValue::String("GOOGL".to_string()),
                DataValue::Float(2800.25),
                DataValue::Integer(50),
            ]))
            .unwrap();

        Arc::new(table)
    }

    /// Registry containing the fixture table under `#instruments`.
    fn registry_with_instruments() -> TempTableRegistry {
        let mut registry = TempTableRegistry::new();
        registry
            .insert("#instruments".to_string(), create_test_table())
            .unwrap();
        registry
    }

    #[test]
    fn test_parse_templates_simple() {
        let registry = TempTableRegistry::new();
        let expander = TemplateExpander::new(&registry);

        let text = "SELECT * FROM external WHERE symbols IN ${#instruments}";
        let vars = expander.parse_templates(text).unwrap();

        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].placeholder, "${#instruments}");
        assert_eq!(vars[0].table_name, "#instruments");
        assert!(vars[0].column.is_none());
        assert!(vars[0].index.is_none());
    }

    #[test]
    fn test_parse_templates_with_column() {
        let registry = TempTableRegistry::new();
        let expander = TemplateExpander::new(&registry);

        let text = "symbols: ${#instruments.symbol}";
        let vars = expander.parse_templates(text).unwrap();

        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].placeholder, "${#instruments.symbol}");
        assert_eq!(vars[0].table_name, "#instruments");
        assert_eq!(vars[0].column, Some("symbol".to_string()));
        assert!(vars[0].index.is_none());
    }

    #[test]
    fn test_parse_templates_with_index() {
        let registry = TempTableRegistry::new();
        let expander = TemplateExpander::new(&registry);

        let text = "first row: ${#instruments[0]}";
        let vars = expander.parse_templates(text).unwrap();

        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].placeholder, "${#instruments[0]}");
        assert_eq!(vars[0].table_name, "#instruments");
        assert!(vars[0].column.is_none());
        assert_eq!(vars[0].index, Some(0));
    }

    #[test]
    fn test_parse_templates_with_index_and_column() {
        let registry = TempTableRegistry::new();
        let expander = TemplateExpander::new(&registry);

        let text = "first symbol: ${#instruments[0].symbol}";
        let vars = expander.parse_templates(text).unwrap();

        assert_eq!(vars.len(), 1);
        assert_eq!(vars[0].placeholder, "${#instruments[0].symbol}");
        assert_eq!(vars[0].table_name, "#instruments");
        assert_eq!(vars[0].column, Some("symbol".to_string()));
        assert_eq!(vars[0].index, Some(0));
    }

    #[test]
    fn test_expand_entire_table() {
        let registry = registry_with_instruments();
        let expander = TemplateExpander::new(&registry);
        let text = "Data: ${#instruments}";
        let vars = expander.parse_templates(text).unwrap();
        let result = expander.expand(text, &vars).unwrap();

        assert!(result.contains("AAPL"));
        assert!(result.contains("GOOGL"));
        assert!(result.contains("150.5"));
        assert!(result.contains("2800.25"));
    }

    #[test]
    fn test_expand_single_row() {
        let registry = registry_with_instruments();
        let expander = TemplateExpander::new(&registry);
        let text = "Second row: ${#instruments[1]}";
        let vars = expander.parse_templates(text).unwrap();
        let result = expander.expand(text, &vars).unwrap();

        assert!(result.contains("GOOGL"));
        assert!(result.contains("2800.25"));
        assert!(!result.contains("AAPL"));
    }

    #[test]
    fn test_expand_single_column() {
        let registry = registry_with_instruments();
        let expander = TemplateExpander::new(&registry);
        let text = "Symbols: ${#instruments.symbol}";
        let vars = expander.parse_templates(text).unwrap();
        let result = expander.expand(text, &vars).unwrap();

        assert!(result.contains("AAPL"));
        assert!(result.contains("GOOGL"));
        assert!(!result.contains("150.5"));
    }

    #[test]
    fn test_expand_column_is_case_insensitive() {
        let registry = registry_with_instruments();
        let expander = TemplateExpander::new(&registry);
        let text = "Symbols: ${#instruments.SYMBOL}";
        let vars = expander.parse_templates(text).unwrap();
        let result = expander.expand(text, &vars).unwrap();

        assert!(result.contains("AAPL"));
        assert!(result.contains("GOOGL"));
    }

    #[test]
    fn test_expand_single_cell() {
        let registry = registry_with_instruments();
        let expander = TemplateExpander::new(&registry);
        let text = "First symbol: ${#instruments[0].symbol}";
        let vars = expander.parse_templates(text).unwrap();
        let result = expander.expand(text, &vars).unwrap();

        assert!(result.contains("AAPL"));
        assert!(!result.contains("GOOGL"));
    }

    #[test]
    fn test_expand_multiple_templates() {
        let registry = registry_with_instruments();
        let expander = TemplateExpander::new(&registry);
        let text = "First: ${#instruments[0].symbol}, All: ${#instruments.symbol}";
        let vars = expander.parse_templates(text).unwrap();
        let result = expander.expand(text, &vars).unwrap();

        assert!(result.contains("AAPL"));
        assert!(result.contains("GOOGL"));
    }

    #[test]
    fn test_expand_missing_table_is_error() {
        let registry = TempTableRegistry::new();
        let expander = TemplateExpander::new(&registry);
        let text = "Data: ${#missing}";
        let vars = expander.parse_templates(text).unwrap();
        let err = expander.expand(text, &vars).unwrap_err();

        assert!(err.to_string().contains("not found"));
    }

    #[test]
    fn test_expand_row_out_of_bounds_is_error() {
        let registry = registry_with_instruments();
        let expander = TemplateExpander::new(&registry);
        let text = "Row: ${#instruments[5]}";
        let vars = expander.parse_templates(text).unwrap();
        let err = expander.expand(text, &vars).unwrap_err();

        assert!(err.to_string().contains("out of bounds"));
    }
}
350}