1use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7use crate::sql::parser::ast::{CTEType, SelectItem, SelectStatement, CTE};
8
/// Serializable summary of a parsed SELECT statement, as produced by `analyze_query`.
#[derive(Serialize, Deserialize, Debug)]
pub struct QueryAnalysis {
    /// Validity flag; `analyze_query` always sets this to `true`.
    pub valid: bool,
    /// Statement kind; `analyze_query` always writes `"SELECT"`.
    pub query_type: String,
    /// `true` when the top-level select list contains a `*` item.
    pub has_star: bool,
    /// One entry per `*` item found in the top-level select list.
    pub star_locations: Vec<StarLocation>,
    /// Table named in the FROM clause, if any (at most one entry from `analyze_query`).
    pub tables: Vec<String>,
    /// Distinct names of plain column items in the select list, in first-seen order.
    pub columns: Vec<String>,
    /// One analysis per CTE of the statement, in declaration order.
    pub ctes: Vec<CteAnalysis>,
    /// Description of the FROM source; `None` when there is neither a table nor a subquery.
    pub from_clause: Option<FromClauseInfo>,
    /// Description of the WHERE clause; `None` when the statement has none.
    pub where_clause: Option<WhereClauseInfo>,
    /// Error messages; `analyze_query` leaves this empty.
    pub errors: Vec<String>,
}
33
/// Position of a `*` select item.
///
/// NOTE(review): `analyze_query` currently fills `line`/`column` with the fixed
/// values 1 and 8 rather than a position computed from the SQL text.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct StarLocation {
    /// Line of the star (presumably 1-based — confirm with consumers).
    pub line: usize,
    /// Column of the star (presumably 1-based — confirm with consumers).
    pub column: usize,
    /// Where the star was found; `analyze_query` writes `"main_query"`.
    pub context: String,
}
44
/// Serializable summary of a single CTE, as produced by `analyze_cte`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct CteAnalysis {
    /// Name the CTE is declared under.
    pub name: String,
    /// `"Standard"` for a SELECT-backed CTE, `"WEB"` for a web CTE.
    pub cte_type: String,
    /// First line of the CTE; `analyze_cte` currently writes the fixed value 1.
    pub start_line: usize,
    /// Last line of the CTE; `analyze_cte` currently writes the fixed value 1.
    pub end_line: usize,
    /// Start offset of the CTE; `analyze_cte` currently writes 0.
    pub start_offset: usize,
    /// End offset of the CTE; `analyze_cte` currently writes 0.
    pub end_offset: usize,
    /// `true` when a standard CTE's select list contains `*`; always `false` for web CTEs.
    pub has_star: bool,
    /// Column names of the CTE; `analyze_cte` leaves this empty.
    pub columns: Vec<String>,
    /// Request settings, present only for web CTEs.
    pub web_config: Option<WebCteConfig>,
}
67
/// Request settings of a web CTE, copied from its AST spec by `analyze_cte`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WebCteConfig {
    /// URL the CTE reads from.
    pub url: String,
    /// `Debug` rendering of the spec's method, or `"GET"` when the spec has none.
    pub method: String,
    /// Header pairs, in the order the spec stores them.
    pub headers: Vec<(String, String)>,
    /// `Debug` rendering of the spec's format, if one was given.
    pub format: Option<String>,
}
80
/// Describes what a statement selects from.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct FromClauseInfo {
    /// `"table"` or `"subquery"` (the two values `analyze_query` writes).
    pub source_type: String,
    /// Table name for a `"table"` source; `None` for a subquery.
    pub name: Option<String>,
}
89
/// Describes a statement's WHERE clause.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct WhereClauseInfo {
    /// `analyze_query` sets this to `true` whenever it builds this value.
    pub present: bool,
    /// Distinct column names found in the conditions, in first-seen order.
    /// `analyze_query` records at most one column per condition.
    pub columns_referenced: Vec<String>,
}
98
/// Result of expanding `*` in a query into explicit columns.
///
/// NOTE(review): nothing in this file constructs this type; the field notes
/// below are inferred from the names and should be confirmed against the producer.
#[derive(Serialize, Deserialize, Debug)]
pub struct ColumnExpansion {
    /// Query text before expansion (presumably as supplied by the caller).
    pub original_query: String,
    /// Query text after expansion.
    pub expanded_query: String,
    /// Columns involved in the expansion.
    pub columns: Vec<ColumnInfo>,
    /// Count associated with the expansion — TODO confirm whether this counts stars or columns.
    pub expansion_count: usize,
    /// Column names keyed by a string (presumably the CTE name — verify against the producer).
    pub cte_columns: HashMap<String, Vec<String>>,
}
113
/// A column name paired with its data type, used by `ColumnExpansion::columns`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ColumnInfo {
    /// Column name.
    pub name: String,
    /// Data type as a string — the vocabulary is not defined in this file; TODO confirm.
    pub data_type: String,
}
122
/// Identifies which part of a statement a cursor position falls in, as
/// produced by `find_query_context`.
#[derive(Serialize, Deserialize, Debug)]
pub struct QueryContext {
    /// `"CTE"` when the position is inside a CTE, otherwise `"main_query"`.
    pub context_type: String,
    /// Name of the enclosing CTE; `None` for the main query.
    pub cte_name: Option<String>,
    /// Zero-based index of the enclosing CTE; `None` for the main query.
    pub cte_index: Option<usize>,
    /// Bounds of the query the position belongs to.
    pub query_bounds: QueryBounds,
    /// Bounds of the surrounding statement; set only for CTE contexts.
    pub parent_query_bounds: Option<QueryBounds>,
    /// `find_query_context` sets this to `false` for web CTEs and `true` otherwise.
    pub can_execute_independently: bool,
}
139
/// Line and offset span of a query or CTE.
///
/// NOTE(review): `find_query_context` fills the offsets with 0 and derives the
/// lines from fixed estimates, so these are not yet real source positions.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct QueryBounds {
    /// First line of the span.
    pub start_line: usize,
    /// Last line of the span (`find_query_context` treats it as inclusive).
    pub end_line: usize,
    /// Start offset of the span — unit (bytes vs. chars) is not established here; TODO confirm.
    pub start_offset: usize,
    /// End offset of the span — same caveat as `start_offset`.
    pub end_offset: usize,
}
152
153pub fn analyze_query(ast: &SelectStatement, _sql: &str) -> QueryAnalysis {
155 let mut analysis = QueryAnalysis {
156 valid: true,
157 query_type: "SELECT".to_string(),
158 has_star: false,
159 star_locations: vec![],
160 tables: vec![],
161 columns: vec![],
162 ctes: vec![],
163 from_clause: None,
164 where_clause: None,
165 errors: vec![],
166 };
167
168 for cte in &ast.ctes {
170 analysis.ctes.push(analyze_cte(cte));
171 }
172
173 for item in &ast.select_items {
175 if matches!(item, SelectItem::Star) {
176 analysis.has_star = true;
177 analysis.star_locations.push(StarLocation {
178 line: 1, column: 8,
180 context: "main_query".to_string(),
181 });
182 }
183 }
184
185 if let Some(ref table) = ast.from_table {
187 let table_name: String = table.clone();
188 analysis.tables.push(table_name.clone());
189 analysis.from_clause = Some(FromClauseInfo {
190 source_type: "table".to_string(),
191 name: Some(table_name),
192 });
193 } else if ast.from_subquery.is_some() {
194 analysis.from_clause = Some(FromClauseInfo {
195 source_type: "subquery".to_string(),
196 name: None,
197 });
198 }
199
200 if let Some(ref where_clause) = ast.where_clause {
202 let mut columns = vec![];
203 for condition in &where_clause.conditions {
205 if let Some(col) = extract_column_from_expr(&condition.expr) {
208 if !columns.contains(&col) {
209 columns.push(col);
210 }
211 }
212 }
213
214 analysis.where_clause = Some(WhereClauseInfo {
215 present: true,
216 columns_referenced: columns,
217 });
218 }
219
220 for item in &ast.select_items {
222 if let SelectItem::Column(col_ref) = item {
223 if !analysis.columns.contains(&col_ref.name) {
224 analysis.columns.push(col_ref.name.clone());
225 }
226 }
227 }
228
229 analysis
230}
231
232fn analyze_cte(cte: &CTE) -> CteAnalysis {
233 let cte_type_str = match &cte.cte_type {
234 CTEType::Standard(_) => "Standard",
235 CTEType::Web(_) => "WEB",
236 };
237
238 let mut has_star = false;
239 let mut web_config = None;
240
241 match &cte.cte_type {
242 CTEType::Standard(stmt) => {
243 for item in &stmt.select_items {
245 if matches!(item, SelectItem::Star) {
246 has_star = true;
247 break;
248 }
249 }
250 }
251 CTEType::Web(web_spec) => {
252 let method_str = match &web_spec.method {
253 Some(m) => format!("{:?}", m),
254 None => "GET".to_string(),
255 };
256 web_config = Some(WebCteConfig {
257 url: web_spec.url.clone(),
258 method: method_str,
259 headers: web_spec.headers.clone(),
260 format: web_spec.format.as_ref().map(|f| format!("{:?}", f)),
261 });
262 }
263 }
264
265 CteAnalysis {
266 name: cte.name.clone(),
267 cte_type: cte_type_str.to_string(),
268 start_line: 1, end_line: 1, start_offset: 0,
271 end_offset: 0,
272 has_star,
273 columns: vec![], web_config,
275 }
276}
277
278fn extract_column_from_expr(expr: &crate::sql::parser::ast::SqlExpression) -> Option<String> {
279 use crate::sql::parser::ast::SqlExpression;
280
281 match expr {
282 SqlExpression::Column(col_ref) => Some(col_ref.name.clone()),
283 SqlExpression::BinaryOp { left, right, .. } => {
284 extract_column_from_expr(left).or_else(|| extract_column_from_expr(right))
286 }
287 SqlExpression::FunctionCall { args, .. } => {
288 args.first().and_then(|arg| extract_column_from_expr(arg))
290 }
291 _ => None,
292 }
293}
294
295pub fn extract_cte(ast: &SelectStatement, cte_name: &str) -> Option<String> {
299 let mut target_index = None;
301 for (idx, cte) in ast.ctes.iter().enumerate() {
302 if cte.name == cte_name {
303 target_index = Some(idx);
304 break;
305 }
306 }
307
308 let target_index = target_index?;
309
310 let mut parts = vec![];
312
313 parts.push("WITH".to_string());
315
316 for (idx, cte) in ast.ctes.iter().enumerate() {
317 if idx > target_index {
318 break; }
320
321 let prefix = if idx == 0 { "" } else { "," };
323
324 match &cte.cte_type {
325 CTEType::Standard(stmt) => {
326 parts.push(format!("{} {} AS (", prefix, cte.name));
327 parts.push(indent_query(&format_select_statement(stmt), 2));
328 parts.push(")".to_string());
329 }
330 CTEType::Web(web_spec) => {
331 parts.push(format!("{} WEB {} AS (", prefix, cte.name));
332 parts.push(format!(" URL '{}'", web_spec.url));
333
334 if let Some(ref m) = web_spec.method {
335 parts.push(format!(" METHOD {:?}", m));
336 }
337
338 if let Some(ref f) = web_spec.format {
339 parts.push(format!(" FORMAT {:?}", f));
340 }
341
342 if let Some(cache) = web_spec.cache_seconds {
343 parts.push(format!(" CACHE {}", cache));
344 }
345
346 if !web_spec.headers.is_empty() {
347 parts.push(" HEADERS (".to_string());
348 for (i, (k, v)) in web_spec.headers.iter().enumerate() {
349 let comma = if i < web_spec.headers.len() - 1 {
350 ","
351 } else {
352 ""
353 };
354 parts.push(format!(" '{}': '{}'{}", k, v, comma));
355 }
356 parts.push(" )".to_string());
357 }
358
359 for (field_name, file_path) in &web_spec.form_files {
361 parts.push(format!(" FORM_FILE '{}' '{}'", field_name, file_path));
362 }
363
364 for (field_name, value) in &web_spec.form_fields {
366 let trimmed_value = value.trim();
367 if (trimmed_value.starts_with('{') && trimmed_value.ends_with('}'))
369 || (trimmed_value.starts_with('[') && trimmed_value.ends_with(']'))
370 {
371 parts.push(format!(
373 " FORM_FIELD '{}' $JSON${}$JSON$",
374 field_name, trimmed_value
375 ));
376 } else {
377 parts.push(format!(" FORM_FIELD '{}' '{}'", field_name, value));
379 }
380 }
381
382 if let Some(ref b) = web_spec.body {
383 let trimmed_body = b.trim();
385 if (trimmed_body.starts_with('{') && trimmed_body.ends_with('}'))
386 || (trimmed_body.starts_with('[') && trimmed_body.ends_with(']'))
387 {
388 parts.push(format!(" BODY $JSON${}$JSON$", trimmed_body));
389 } else {
390 parts.push(format!(" BODY '{}'", b));
391 }
392 }
393
394 if let Some(ref jp) = web_spec.json_path {
395 parts.push(format!(" JSON_PATH '{}'", jp));
396 }
397
398 parts.push(")".to_string());
399 }
400 }
401 }
402
403 parts.push(format!("SELECT * FROM {}", cte_name));
405
406 Some(parts.join("\n"))
407}
408
/// Prefixes every line of `query` with `spaces` space characters.
///
/// Lines are split as by [`str::lines`] and re-joined with `\n`, so a trailing
/// newline is dropped and `\r\n` endings become `\n`. Blank lines receive the
/// indent too. An empty `query` yields an empty string.
fn indent_query(query: &str, spaces: usize) -> String {
    let pad = " ".repeat(spaces);
    let mut out = String::with_capacity(query.len() + spaces);

    for (n, line) in query.lines().enumerate() {
        if n > 0 {
            out.push('\n');
        }
        out.push_str(&pad);
        out.push_str(line);
    }

    out
}
417
418fn format_cte_as_query(cte: &CTE) -> String {
419 match &cte.cte_type {
420 CTEType::Standard(stmt) => {
421 format_select_statement(stmt)
424 }
425 CTEType::Web(web_spec) => {
426 let mut parts = vec![
428 format!("WITH WEB {} AS (", cte.name),
429 format!(" URL '{}'", web_spec.url),
430 ];
431
432 if let Some(ref m) = web_spec.method {
433 parts.push(format!(" METHOD {:?}", m));
434 }
435
436 if !web_spec.headers.is_empty() {
437 parts.push(" HEADERS (".to_string());
438 for (k, v) in &web_spec.headers {
439 parts.push(format!(" '{}' = '{}'", k, v));
440 }
441 parts.push(" )".to_string());
442 }
443
444 if let Some(ref b) = web_spec.body {
445 parts.push(format!(" BODY '{}'", b));
446 }
447
448 if let Some(ref f) = web_spec.format {
449 parts.push(format!(" FORMAT {:?}", f));
450 }
451
452 parts.push(")".to_string());
453 parts.push(format!("SELECT * FROM {}", cte.name));
454
455 parts.join("\n")
456 }
457 }
458}
459
460fn format_select_statement(stmt: &SelectStatement) -> String {
461 let mut parts = vec!["SELECT".to_string()];
462
463 if stmt.select_items.is_empty() {
465 parts.push(" *".to_string());
466 } else {
467 for (i, item) in stmt.select_items.iter().enumerate() {
468 let prefix = if i == 0 { " " } else { " , " };
469 match item {
470 SelectItem::Star => parts.push(format!("{}*", prefix)),
471 SelectItem::Column(col) => {
472 parts.push(format!("{}{}", prefix, col.name));
473 }
474 SelectItem::Expression { expr, alias } => {
475 let expr_str = format_expr(expr);
476 parts.push(format!("{}{} AS {}", prefix, expr_str, alias));
477 }
478 }
479 }
480 }
481
482 if let Some(ref table) = stmt.from_table {
484 parts.push(format!("FROM {}", table));
485 }
486
487 if let Some(ref where_clause) = stmt.where_clause {
489 parts.push("WHERE".to_string());
490 for (i, condition) in where_clause.conditions.iter().enumerate() {
491 let connector = if i > 0 {
492 condition
493 .connector
494 .as_ref()
495 .map(|op| match op {
496 crate::sql::parser::ast::LogicalOp::And => "AND",
497 crate::sql::parser::ast::LogicalOp::Or => "OR",
498 })
499 .unwrap_or("AND")
500 } else {
501 ""
502 };
503 let expr_str = format_expr(&condition.expr);
504 if i == 0 {
505 parts.push(format!(" {}", expr_str));
506 } else {
507 parts.push(format!(" {} {}", connector, expr_str));
508 }
509 }
510 }
511
512 if let Some(limit) = stmt.limit {
514 parts.push(format!("LIMIT {}", limit));
515 }
516
517 parts.join("\n")
518}
519
/// Renders `expr` as text by delegating to the crate's
/// `ast_formatter::format_expression`; kept as a local shorthand for the
/// formatting helpers in this file.
fn format_expr(expr: &crate::sql::parser::ast::SqlExpression) -> String {
    crate::sql::parser::ast_formatter::format_expression(expr)
}
525
526pub fn find_query_context(ast: &SelectStatement, line: usize, _column: usize) -> QueryContext {
528 for (idx, cte) in ast.ctes.iter().enumerate() {
530 let cte_start = 1 + (idx * 5);
533 let cte_end = cte_start + 4;
534
535 if line >= cte_start && line <= cte_end {
536 return QueryContext {
537 context_type: "CTE".to_string(),
538 cte_name: Some(cte.name.clone()),
539 cte_index: Some(idx),
540 query_bounds: QueryBounds {
541 start_line: cte_start,
542 end_line: cte_end,
543 start_offset: 0,
544 end_offset: 0,
545 },
546 parent_query_bounds: Some(QueryBounds {
547 start_line: 1,
548 end_line: 100, start_offset: 0,
550 end_offset: 0,
551 }),
552 can_execute_independently: !matches!(cte.cte_type, CTEType::Web(_)),
553 };
554 }
555 }
556
557 QueryContext {
559 context_type: "main_query".to_string(),
560 cte_name: None,
561 cte_index: None,
562 query_bounds: QueryBounds {
563 start_line: 1,
564 end_line: 100, start_offset: 0,
566 end_offset: 0,
567 },
568 parent_query_bounds: None,
569 can_execute_independently: true,
570 }
571}
572
// Unit tests that parse real SQL with the crate's recursive parser and run the
// analysis helpers on the resulting AST.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::recursive_parser::Parser;

    // A plain `SELECT *` is reported as a valid SELECT with one star and one table.
    #[test]
    fn test_analyze_simple_query() {
        let sql = "SELECT * FROM trades WHERE price > 100";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();

        let analysis = analyze_query(&ast, sql);

        assert!(analysis.valid);
        assert_eq!(analysis.query_type, "SELECT");
        assert!(analysis.has_star);
        assert_eq!(analysis.star_locations.len(), 1);
        assert_eq!(analysis.tables, vec!["trades"]);
    }

    // A standard CTE is listed by name and type, and its inner `*` is detected.
    #[test]
    fn test_analyze_cte_query() {
        let sql = "WITH trades AS (SELECT * FROM raw_trades) SELECT symbol FROM trades";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();

        let analysis = analyze_query(&ast, sql);

        assert!(analysis.valid);
        assert_eq!(analysis.ctes.len(), 1);
        assert_eq!(analysis.ctes[0].name, "trades");
        assert_eq!(analysis.ctes[0].cte_type, "Standard");
        assert!(analysis.ctes[0].has_star);
    }

    // The extracted query keeps the CTE's source table and its WHERE condition.
    #[test]
    fn test_extract_cte() {
        let sql =
            "WITH trades AS (SELECT * FROM raw_trades WHERE price > 100) SELECT * FROM trades";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().unwrap();

        let extracted = extract_cte(&ast, "trades").unwrap();

        assert!(extracted.contains("SELECT"));
        assert!(extracted.contains("raw_trades"));
        assert!(extracted.contains("price > 100"));
    }
}