1pub mod statement_dependencies;
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9use crate::sql::parser::ast::{CTEType, SelectItem, SelectStatement, CTE};
10
11#[derive(Serialize, Deserialize, Debug)]
13pub struct QueryAnalysis {
14 pub valid: bool,
16 pub query_type: String,
18 pub has_star: bool,
20 pub star_locations: Vec<StarLocation>,
22 pub tables: Vec<String>,
24 pub columns: Vec<String>,
26 pub ctes: Vec<CteAnalysis>,
28 pub from_clause: Option<FromClauseInfo>,
30 pub where_clause: Option<WhereClauseInfo>,
32 pub errors: Vec<String>,
34}
35
36#[derive(Serialize, Deserialize, Debug, Clone)]
38pub struct StarLocation {
39 pub line: usize,
41 pub column: usize,
43 pub context: String,
45}
46
47#[derive(Serialize, Deserialize, Debug, Clone)]
49pub struct CteAnalysis {
50 pub name: String,
52 pub cte_type: String,
54 pub start_line: usize,
56 pub end_line: usize,
58 pub start_offset: usize,
60 pub end_offset: usize,
62 pub has_star: bool,
64 pub columns: Vec<String>,
66 pub web_config: Option<WebCteConfig>,
68}
69
70#[derive(Serialize, Deserialize, Debug, Clone)]
72pub struct WebCteConfig {
73 pub url: String,
75 pub method: String,
77 pub headers: Vec<(String, String)>,
79 pub format: Option<String>,
81}
82
83#[derive(Serialize, Deserialize, Debug, Clone)]
85pub struct FromClauseInfo {
86 pub source_type: String,
88 pub name: Option<String>,
90}
91
92#[derive(Serialize, Deserialize, Debug, Clone)]
94pub struct WhereClauseInfo {
95 pub present: bool,
97 pub columns_referenced: Vec<String>,
99}
100
101#[derive(Serialize, Deserialize, Debug)]
103pub struct ColumnExpansion {
104 pub original_query: String,
106 pub expanded_query: String,
108 pub columns: Vec<ColumnInfo>,
110 pub expansion_count: usize,
112 pub cte_columns: HashMap<String, Vec<String>>,
114}
115
116#[derive(Serialize, Deserialize, Debug, Clone)]
118pub struct ColumnInfo {
119 pub name: String,
121 pub data_type: String,
123}
124
125#[derive(Serialize, Deserialize, Debug)]
127pub struct QueryContext {
128 pub context_type: String,
130 pub cte_name: Option<String>,
132 pub cte_index: Option<usize>,
134 pub query_bounds: QueryBounds,
136 pub parent_query_bounds: Option<QueryBounds>,
138 pub can_execute_independently: bool,
140}
141
142#[derive(Serialize, Deserialize, Debug, Clone)]
144pub struct QueryBounds {
145 pub start_line: usize,
147 pub end_line: usize,
149 pub start_offset: usize,
151 pub end_offset: usize,
153}
154
155pub fn analyze_query(ast: &SelectStatement, _sql: &str) -> QueryAnalysis {
157 let mut analysis = QueryAnalysis {
158 valid: true,
159 query_type: "SELECT".to_string(),
160 has_star: false,
161 star_locations: vec![],
162 tables: vec![],
163 columns: vec![],
164 ctes: vec![],
165 from_clause: None,
166 where_clause: None,
167 errors: vec![],
168 };
169
170 for cte in &ast.ctes {
172 analysis.ctes.push(analyze_cte(cte));
173 }
174
175 for item in &ast.select_items {
177 if matches!(item, SelectItem::Star { .. }) {
178 analysis.has_star = true;
179 analysis.star_locations.push(StarLocation {
180 line: 1, column: 8,
182 context: "main_query".to_string(),
183 });
184 }
185 }
186
187 if let Some(ref table) = ast.from_table {
189 let table_name: String = table.clone();
190 analysis.tables.push(table_name.clone());
191 analysis.from_clause = Some(FromClauseInfo {
192 source_type: "table".to_string(),
193 name: Some(table_name),
194 });
195 } else if ast.from_subquery.is_some() {
196 analysis.from_clause = Some(FromClauseInfo {
197 source_type: "subquery".to_string(),
198 name: None,
199 });
200 }
201
202 if let Some(ref where_clause) = ast.where_clause {
204 let mut columns = vec![];
205 for condition in &where_clause.conditions {
207 if let Some(col) = extract_column_from_expr(&condition.expr) {
210 if !columns.contains(&col) {
211 columns.push(col);
212 }
213 }
214 }
215
216 analysis.where_clause = Some(WhereClauseInfo {
217 present: true,
218 columns_referenced: columns,
219 });
220 }
221
222 for item in &ast.select_items {
224 if let SelectItem::Column {
225 column: col_ref, ..
226 } = item
227 {
228 if !analysis.columns.contains(&col_ref.name) {
229 analysis.columns.push(col_ref.name.clone());
230 }
231 }
232 }
233
234 analysis
235}
236
237fn analyze_cte(cte: &CTE) -> CteAnalysis {
238 let cte_type_str = match &cte.cte_type {
239 CTEType::Standard(_) => "Standard",
240 CTEType::Web(_) => "WEB",
241 };
242
243 let mut has_star = false;
244 let mut web_config = None;
245
246 match &cte.cte_type {
247 CTEType::Standard(stmt) => {
248 for item in &stmt.select_items {
250 if matches!(item, SelectItem::Star { .. }) {
251 has_star = true;
252 break;
253 }
254 }
255 }
256 CTEType::Web(web_spec) => {
257 let method_str = match &web_spec.method {
258 Some(m) => format!("{:?}", m),
259 None => "GET".to_string(),
260 };
261 web_config = Some(WebCteConfig {
262 url: web_spec.url.clone(),
263 method: method_str,
264 headers: web_spec.headers.clone(),
265 format: web_spec.format.as_ref().map(|f| format!("{:?}", f)),
266 });
267 }
268 }
269
270 CteAnalysis {
271 name: cte.name.clone(),
272 cte_type: cte_type_str.to_string(),
273 start_line: 1, end_line: 1, start_offset: 0,
276 end_offset: 0,
277 has_star,
278 columns: vec![], web_config,
280 }
281}
282
283fn extract_column_from_expr(expr: &crate::sql::parser::ast::SqlExpression) -> Option<String> {
284 use crate::sql::parser::ast::SqlExpression;
285
286 match expr {
287 SqlExpression::Column(col_ref) => Some(col_ref.name.clone()),
288 SqlExpression::BinaryOp { left, right, .. } => {
289 extract_column_from_expr(left).or_else(|| extract_column_from_expr(right))
291 }
292 SqlExpression::FunctionCall { args, .. } => {
293 args.first().and_then(|arg| extract_column_from_expr(arg))
295 }
296 _ => None,
297 }
298}
299
300pub fn extract_cte(ast: &SelectStatement, cte_name: &str) -> Option<String> {
304 let mut target_index = None;
306 for (idx, cte) in ast.ctes.iter().enumerate() {
307 if cte.name == cte_name {
308 target_index = Some(idx);
309 break;
310 }
311 }
312
313 let target_index = target_index?;
314
315 let mut parts = vec![];
317
318 parts.push("WITH".to_string());
320
321 for (idx, cte) in ast.ctes.iter().enumerate() {
322 if idx > target_index {
323 break; }
325
326 let prefix = if idx == 0 { "" } else { "," };
328
329 match &cte.cte_type {
330 CTEType::Standard(stmt) => {
331 parts.push(format!("{} {} AS (", prefix, cte.name));
332 parts.push(indent_query(&format_select_statement(stmt), 2));
333 parts.push(")".to_string());
334 }
335 CTEType::Web(web_spec) => {
336 parts.push(format!("{} WEB {} AS (", prefix, cte.name));
337 parts.push(format!(" URL '{}'", web_spec.url));
338
339 if let Some(ref m) = web_spec.method {
340 parts.push(format!(" METHOD {:?}", m));
341 }
342
343 if let Some(ref f) = web_spec.format {
344 parts.push(format!(" FORMAT {:?}", f));
345 }
346
347 if let Some(cache) = web_spec.cache_seconds {
348 parts.push(format!(" CACHE {}", cache));
349 }
350
351 if !web_spec.headers.is_empty() {
352 parts.push(" HEADERS (".to_string());
353 for (i, (k, v)) in web_spec.headers.iter().enumerate() {
354 let comma = if i < web_spec.headers.len() - 1 {
355 ","
356 } else {
357 ""
358 };
359 parts.push(format!(" '{}': '{}'{}", k, v, comma));
360 }
361 parts.push(" )".to_string());
362 }
363
364 for (field_name, file_path) in &web_spec.form_files {
366 parts.push(format!(" FORM_FILE '{}' '{}'", field_name, file_path));
367 }
368
369 for (field_name, value) in &web_spec.form_fields {
371 let trimmed_value = value.trim();
372 if (trimmed_value.starts_with('{') && trimmed_value.ends_with('}'))
374 || (trimmed_value.starts_with('[') && trimmed_value.ends_with(']'))
375 {
376 parts.push(format!(
378 " FORM_FIELD '{}' $JSON${}$JSON$",
379 field_name, trimmed_value
380 ));
381 } else {
382 parts.push(format!(" FORM_FIELD '{}' '{}'", field_name, value));
384 }
385 }
386
387 if let Some(ref b) = web_spec.body {
388 let trimmed_body = b.trim();
390 if (trimmed_body.starts_with('{') && trimmed_body.ends_with('}'))
391 || (trimmed_body.starts_with('[') && trimmed_body.ends_with(']'))
392 {
393 parts.push(format!(" BODY $JSON${}$JSON$", trimmed_body));
394 } else {
395 parts.push(format!(" BODY '{}'", b));
396 }
397 }
398
399 if let Some(ref jp) = web_spec.json_path {
400 parts.push(format!(" JSON_PATH '{}'", jp));
401 }
402
403 parts.push(")".to_string());
404 }
405 }
406 }
407
408 parts.push(format!("SELECT * FROM {}", cte_name));
410
411 Some(parts.join("\n"))
412}
413
/// Prefixes every line of `query` with `spaces` space characters.
///
/// Lines are split with `str::lines`, so both `\n` and `\r\n` terminators are accepted and a
/// trailing terminator adds no extra line. The result is joined with `\n` and has no trailing
/// newline. An empty `query` yields an empty string.
fn indent_query(query: &str, spaces: usize) -> String {
    let pad = " ".repeat(spaces);
    let mut indented = String::with_capacity(query.len() + spaces);
    for (n, line) in query.lines().enumerate() {
        if n > 0 {
            indented.push('\n');
        }
        indented.push_str(&pad);
        indented.push_str(line);
    }
    indented
}
422
423fn format_cte_as_query(cte: &CTE) -> String {
424 match &cte.cte_type {
425 CTEType::Standard(stmt) => {
426 format_select_statement(stmt)
429 }
430 CTEType::Web(web_spec) => {
431 let mut parts = vec![
433 format!("WITH WEB {} AS (", cte.name),
434 format!(" URL '{}'", web_spec.url),
435 ];
436
437 if let Some(ref m) = web_spec.method {
438 parts.push(format!(" METHOD {:?}", m));
439 }
440
441 if !web_spec.headers.is_empty() {
442 parts.push(" HEADERS (".to_string());
443 for (k, v) in &web_spec.headers {
444 parts.push(format!(" '{}' = '{}'", k, v));
445 }
446 parts.push(" )".to_string());
447 }
448
449 if let Some(ref b) = web_spec.body {
450 parts.push(format!(" BODY '{}'", b));
451 }
452
453 if let Some(ref f) = web_spec.format {
454 parts.push(format!(" FORMAT {:?}", f));
455 }
456
457 parts.push(")".to_string());
458 parts.push(format!("SELECT * FROM {}", cte.name));
459
460 parts.join("\n")
461 }
462 }
463}
464
465fn format_select_statement(stmt: &SelectStatement) -> String {
466 let mut parts = vec!["SELECT".to_string()];
467
468 if stmt.select_items.is_empty() {
470 parts.push(" *".to_string());
471 } else {
472 for (i, item) in stmt.select_items.iter().enumerate() {
473 let prefix = if i == 0 { " " } else { " , " };
474 match item {
475 SelectItem::Star { .. } => parts.push(format!("{}*", prefix)),
476 SelectItem::Column { column: col, .. } => {
477 parts.push(format!("{}{}", prefix, col.name));
478 }
479 SelectItem::Expression { expr, alias, .. } => {
480 let expr_str = format_expr(expr);
481 parts.push(format!("{}{} AS {}", prefix, expr_str, alias));
482 }
483 }
484 }
485 }
486
487 if let Some(ref table) = stmt.from_table {
489 parts.push(format!("FROM {}", table));
490 }
491
492 if let Some(ref where_clause) = stmt.where_clause {
494 parts.push("WHERE".to_string());
495 for (i, condition) in where_clause.conditions.iter().enumerate() {
496 let connector = if i > 0 {
497 condition
498 .connector
499 .as_ref()
500 .map(|op| match op {
501 crate::sql::parser::ast::LogicalOp::And => "AND",
502 crate::sql::parser::ast::LogicalOp::Or => "OR",
503 })
504 .unwrap_or("AND")
505 } else {
506 ""
507 };
508 let expr_str = format_expr(&condition.expr);
509 if i == 0 {
510 parts.push(format!(" {}", expr_str));
511 } else {
512 parts.push(format!(" {} {}", connector, expr_str));
513 }
514 }
515 }
516
517 if let Some(limit) = stmt.limit {
519 parts.push(format!("LIMIT {}", limit));
520 }
521
522 parts.join("\n")
523}
524
525fn format_expr(expr: &crate::sql::parser::ast::SqlExpression) -> String {
528 crate::sql::parser::ast_formatter::format_expression(expr)
529}
530
531pub fn find_query_context(ast: &SelectStatement, line: usize, _column: usize) -> QueryContext {
533 for (idx, cte) in ast.ctes.iter().enumerate() {
535 let cte_start = 1 + (idx * 5);
538 let cte_end = cte_start + 4;
539
540 if line >= cte_start && line <= cte_end {
541 return QueryContext {
542 context_type: "CTE".to_string(),
543 cte_name: Some(cte.name.clone()),
544 cte_index: Some(idx),
545 query_bounds: QueryBounds {
546 start_line: cte_start,
547 end_line: cte_end,
548 start_offset: 0,
549 end_offset: 0,
550 },
551 parent_query_bounds: Some(QueryBounds {
552 start_line: 1,
553 end_line: 100, start_offset: 0,
555 end_offset: 0,
556 }),
557 can_execute_independently: !matches!(cte.cte_type, CTEType::Web(_)),
558 };
559 }
560 }
561
562 QueryContext {
564 context_type: "main_query".to_string(),
565 cte_name: None,
566 cte_index: None,
567 query_bounds: QueryBounds {
568 start_line: 1,
569 end_line: 100, start_offset: 0,
571 end_offset: 0,
572 },
573 parent_query_bounds: None,
574 can_execute_independently: true,
575 }
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::recursive_parser::Parser;

    #[test]
    fn test_analyze_simple_query() {
        let sql = "SELECT * FROM trades WHERE price > 100";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().expect("simple query should parse");

        let analysis = analyze_query(&ast, sql);

        assert!(analysis.valid);
        assert_eq!(analysis.query_type, "SELECT");
        assert!(analysis.has_star, "SELECT * should be detected");
        assert_eq!(analysis.star_locations.len(), 1);
        assert_eq!(analysis.tables, vec!["trades"]);
    }

    #[test]
    fn test_analyze_cte_query() {
        let sql = "WITH trades AS (SELECT * FROM raw_trades) SELECT symbol FROM trades";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().expect("CTE query should parse");

        let analysis = analyze_query(&ast, sql);

        assert!(analysis.valid);
        assert_eq!(analysis.ctes.len(), 1);
        let cte = &analysis.ctes[0];
        assert_eq!(cte.name, "trades");
        assert_eq!(cte.cte_type, "Standard");
        assert!(cte.has_star, "the CTE body selects *");
    }

    #[test]
    fn test_extract_cte() {
        let sql =
            "WITH trades AS (SELECT * FROM raw_trades WHERE price > 100) SELECT * FROM trades";
        let mut parser = Parser::new(sql);
        let ast = parser.parse().expect("CTE query should parse");

        let extracted = extract_cte(&ast, "trades").expect("CTE `trades` should be found");

        for needle in ["SELECT", "raw_trades", "price > 100"].iter() {
            assert!(
                extracted.contains(*needle),
                "missing {:?} in extracted query:\n{}",
                needle,
                extracted
            );
        }
    }
}