1pub mod statement_dependencies;
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9use crate::sql::parser::ast::{CTEType, SelectItem, SelectStatement, CTE};
10
11#[derive(Serialize, Deserialize, Debug)]
13pub struct QueryAnalysis {
14 pub valid: bool,
16 pub query_type: String,
18 pub has_star: bool,
20 pub star_locations: Vec<StarLocation>,
22 pub tables: Vec<String>,
24 pub columns: Vec<String>,
26 pub ctes: Vec<CteAnalysis>,
28 pub from_clause: Option<FromClauseInfo>,
30 pub where_clause: Option<WhereClauseInfo>,
32 pub errors: Vec<String>,
34}
35
36#[derive(Serialize, Deserialize, Debug, Clone)]
38pub struct StarLocation {
39 pub line: usize,
41 pub column: usize,
43 pub context: String,
45}
46
47#[derive(Serialize, Deserialize, Debug, Clone)]
49pub struct CteAnalysis {
50 pub name: String,
52 pub cte_type: String,
54 pub start_line: usize,
56 pub end_line: usize,
58 pub start_offset: usize,
60 pub end_offset: usize,
62 pub has_star: bool,
64 pub columns: Vec<String>,
66 pub web_config: Option<WebCteConfig>,
68}
69
70#[derive(Serialize, Deserialize, Debug, Clone)]
72pub struct WebCteConfig {
73 pub url: String,
75 pub method: String,
77 pub headers: Vec<(String, String)>,
79 pub format: Option<String>,
81}
82
83#[derive(Serialize, Deserialize, Debug, Clone)]
85pub struct FromClauseInfo {
86 pub source_type: String,
88 pub name: Option<String>,
90}
91
92#[derive(Serialize, Deserialize, Debug, Clone)]
94pub struct WhereClauseInfo {
95 pub present: bool,
97 pub columns_referenced: Vec<String>,
99}
100
101#[derive(Serialize, Deserialize, Debug)]
103pub struct ColumnExpansion {
104 pub original_query: String,
106 pub expanded_query: String,
108 pub columns: Vec<ColumnInfo>,
110 pub expansion_count: usize,
112 pub cte_columns: HashMap<String, Vec<String>>,
114}
115
116#[derive(Serialize, Deserialize, Debug, Clone)]
118pub struct ColumnInfo {
119 pub name: String,
121 pub data_type: String,
123}
124
125#[derive(Serialize, Deserialize, Debug)]
127pub struct QueryContext {
128 pub context_type: String,
130 pub cte_name: Option<String>,
132 pub cte_index: Option<usize>,
134 pub query_bounds: QueryBounds,
136 pub parent_query_bounds: Option<QueryBounds>,
138 pub can_execute_independently: bool,
140}
141
142#[derive(Serialize, Deserialize, Debug, Clone)]
144pub struct QueryBounds {
145 pub start_line: usize,
147 pub end_line: usize,
149 pub start_offset: usize,
151 pub end_offset: usize,
153}
154
155pub fn analyze_query(ast: &SelectStatement, _sql: &str) -> QueryAnalysis {
157 let mut analysis = QueryAnalysis {
158 valid: true,
159 query_type: "SELECT".to_string(),
160 has_star: false,
161 star_locations: vec![],
162 tables: vec![],
163 columns: vec![],
164 ctes: vec![],
165 from_clause: None,
166 where_clause: None,
167 errors: vec![],
168 };
169
170 for cte in &ast.ctes {
172 analysis.ctes.push(analyze_cte(cte));
173 }
174
175 for item in &ast.select_items {
177 if matches!(item, SelectItem::Star) {
178 analysis.has_star = true;
179 analysis.star_locations.push(StarLocation {
180 line: 1, column: 8,
182 context: "main_query".to_string(),
183 });
184 }
185 }
186
187 if let Some(ref table) = ast.from_table {
189 let table_name: String = table.clone();
190 analysis.tables.push(table_name.clone());
191 analysis.from_clause = Some(FromClauseInfo {
192 source_type: "table".to_string(),
193 name: Some(table_name),
194 });
195 } else if ast.from_subquery.is_some() {
196 analysis.from_clause = Some(FromClauseInfo {
197 source_type: "subquery".to_string(),
198 name: None,
199 });
200 }
201
202 if let Some(ref where_clause) = ast.where_clause {
204 let mut columns = vec![];
205 for condition in &where_clause.conditions {
207 if let Some(col) = extract_column_from_expr(&condition.expr) {
210 if !columns.contains(&col) {
211 columns.push(col);
212 }
213 }
214 }
215
216 analysis.where_clause = Some(WhereClauseInfo {
217 present: true,
218 columns_referenced: columns,
219 });
220 }
221
222 for item in &ast.select_items {
224 if let SelectItem::Column(col_ref) = item {
225 if !analysis.columns.contains(&col_ref.name) {
226 analysis.columns.push(col_ref.name.clone());
227 }
228 }
229 }
230
231 analysis
232}
233
234fn analyze_cte(cte: &CTE) -> CteAnalysis {
235 let cte_type_str = match &cte.cte_type {
236 CTEType::Standard(_) => "Standard",
237 CTEType::Web(_) => "WEB",
238 };
239
240 let mut has_star = false;
241 let mut web_config = None;
242
243 match &cte.cte_type {
244 CTEType::Standard(stmt) => {
245 for item in &stmt.select_items {
247 if matches!(item, SelectItem::Star) {
248 has_star = true;
249 break;
250 }
251 }
252 }
253 CTEType::Web(web_spec) => {
254 let method_str = match &web_spec.method {
255 Some(m) => format!("{:?}", m),
256 None => "GET".to_string(),
257 };
258 web_config = Some(WebCteConfig {
259 url: web_spec.url.clone(),
260 method: method_str,
261 headers: web_spec.headers.clone(),
262 format: web_spec.format.as_ref().map(|f| format!("{:?}", f)),
263 });
264 }
265 }
266
267 CteAnalysis {
268 name: cte.name.clone(),
269 cte_type: cte_type_str.to_string(),
270 start_line: 1, end_line: 1, start_offset: 0,
273 end_offset: 0,
274 has_star,
275 columns: vec![], web_config,
277 }
278}
279
280fn extract_column_from_expr(expr: &crate::sql::parser::ast::SqlExpression) -> Option<String> {
281 use crate::sql::parser::ast::SqlExpression;
282
283 match expr {
284 SqlExpression::Column(col_ref) => Some(col_ref.name.clone()),
285 SqlExpression::BinaryOp { left, right, .. } => {
286 extract_column_from_expr(left).or_else(|| extract_column_from_expr(right))
288 }
289 SqlExpression::FunctionCall { args, .. } => {
290 args.first().and_then(|arg| extract_column_from_expr(arg))
292 }
293 _ => None,
294 }
295}
296
297pub fn extract_cte(ast: &SelectStatement, cte_name: &str) -> Option<String> {
301 let mut target_index = None;
303 for (idx, cte) in ast.ctes.iter().enumerate() {
304 if cte.name == cte_name {
305 target_index = Some(idx);
306 break;
307 }
308 }
309
310 let target_index = target_index?;
311
312 let mut parts = vec![];
314
315 parts.push("WITH".to_string());
317
318 for (idx, cte) in ast.ctes.iter().enumerate() {
319 if idx > target_index {
320 break; }
322
323 let prefix = if idx == 0 { "" } else { "," };
325
326 match &cte.cte_type {
327 CTEType::Standard(stmt) => {
328 parts.push(format!("{} {} AS (", prefix, cte.name));
329 parts.push(indent_query(&format_select_statement(stmt), 2));
330 parts.push(")".to_string());
331 }
332 CTEType::Web(web_spec) => {
333 parts.push(format!("{} WEB {} AS (", prefix, cte.name));
334 parts.push(format!(" URL '{}'", web_spec.url));
335
336 if let Some(ref m) = web_spec.method {
337 parts.push(format!(" METHOD {:?}", m));
338 }
339
340 if let Some(ref f) = web_spec.format {
341 parts.push(format!(" FORMAT {:?}", f));
342 }
343
344 if let Some(cache) = web_spec.cache_seconds {
345 parts.push(format!(" CACHE {}", cache));
346 }
347
348 if !web_spec.headers.is_empty() {
349 parts.push(" HEADERS (".to_string());
350 for (i, (k, v)) in web_spec.headers.iter().enumerate() {
351 let comma = if i < web_spec.headers.len() - 1 {
352 ","
353 } else {
354 ""
355 };
356 parts.push(format!(" '{}': '{}'{}", k, v, comma));
357 }
358 parts.push(" )".to_string());
359 }
360
361 for (field_name, file_path) in &web_spec.form_files {
363 parts.push(format!(" FORM_FILE '{}' '{}'", field_name, file_path));
364 }
365
366 for (field_name, value) in &web_spec.form_fields {
368 let trimmed_value = value.trim();
369 if (trimmed_value.starts_with('{') && trimmed_value.ends_with('}'))
371 || (trimmed_value.starts_with('[') && trimmed_value.ends_with(']'))
372 {
373 parts.push(format!(
375 " FORM_FIELD '{}' $JSON${}$JSON$",
376 field_name, trimmed_value
377 ));
378 } else {
379 parts.push(format!(" FORM_FIELD '{}' '{}'", field_name, value));
381 }
382 }
383
384 if let Some(ref b) = web_spec.body {
385 let trimmed_body = b.trim();
387 if (trimmed_body.starts_with('{') && trimmed_body.ends_with('}'))
388 || (trimmed_body.starts_with('[') && trimmed_body.ends_with(']'))
389 {
390 parts.push(format!(" BODY $JSON${}$JSON$", trimmed_body));
391 } else {
392 parts.push(format!(" BODY '{}'", b));
393 }
394 }
395
396 if let Some(ref jp) = web_spec.json_path {
397 parts.push(format!(" JSON_PATH '{}'", jp));
398 }
399
400 parts.push(")".to_string());
401 }
402 }
403 }
404
405 parts.push(format!("SELECT * FROM {}", cte_name));
407
408 Some(parts.join("\n"))
409}
410
/// Prefixes every line of `query` with `spaces` space characters.
///
/// Lines are split with [`str::lines`], so a trailing newline is dropped and
/// the result is re-joined with `\n`. An empty input yields an empty string.
fn indent_query(query: &str, spaces: usize) -> String {
    let pad = " ".repeat(spaces);
    let mut out = String::with_capacity(query.len() + pad.len());

    for (n, line) in query.lines().enumerate() {
        if n > 0 {
            out.push('\n');
        }
        out.push_str(&pad);
        out.push_str(line);
    }

    out
}
419
420fn format_cte_as_query(cte: &CTE) -> String {
421 match &cte.cte_type {
422 CTEType::Standard(stmt) => {
423 format_select_statement(stmt)
426 }
427 CTEType::Web(web_spec) => {
428 let mut parts = vec![
430 format!("WITH WEB {} AS (", cte.name),
431 format!(" URL '{}'", web_spec.url),
432 ];
433
434 if let Some(ref m) = web_spec.method {
435 parts.push(format!(" METHOD {:?}", m));
436 }
437
438 if !web_spec.headers.is_empty() {
439 parts.push(" HEADERS (".to_string());
440 for (k, v) in &web_spec.headers {
441 parts.push(format!(" '{}' = '{}'", k, v));
442 }
443 parts.push(" )".to_string());
444 }
445
446 if let Some(ref b) = web_spec.body {
447 parts.push(format!(" BODY '{}'", b));
448 }
449
450 if let Some(ref f) = web_spec.format {
451 parts.push(format!(" FORMAT {:?}", f));
452 }
453
454 parts.push(")".to_string());
455 parts.push(format!("SELECT * FROM {}", cte.name));
456
457 parts.join("\n")
458 }
459 }
460}
461
462fn format_select_statement(stmt: &SelectStatement) -> String {
463 let mut parts = vec!["SELECT".to_string()];
464
465 if stmt.select_items.is_empty() {
467 parts.push(" *".to_string());
468 } else {
469 for (i, item) in stmt.select_items.iter().enumerate() {
470 let prefix = if i == 0 { " " } else { " , " };
471 match item {
472 SelectItem::Star => parts.push(format!("{}*", prefix)),
473 SelectItem::Column(col) => {
474 parts.push(format!("{}{}", prefix, col.name));
475 }
476 SelectItem::Expression { expr, alias } => {
477 let expr_str = format_expr(expr);
478 parts.push(format!("{}{} AS {}", prefix, expr_str, alias));
479 }
480 }
481 }
482 }
483
484 if let Some(ref table) = stmt.from_table {
486 parts.push(format!("FROM {}", table));
487 }
488
489 if let Some(ref where_clause) = stmt.where_clause {
491 parts.push("WHERE".to_string());
492 for (i, condition) in where_clause.conditions.iter().enumerate() {
493 let connector = if i > 0 {
494 condition
495 .connector
496 .as_ref()
497 .map(|op| match op {
498 crate::sql::parser::ast::LogicalOp::And => "AND",
499 crate::sql::parser::ast::LogicalOp::Or => "OR",
500 })
501 .unwrap_or("AND")
502 } else {
503 ""
504 };
505 let expr_str = format_expr(&condition.expr);
506 if i == 0 {
507 parts.push(format!(" {}", expr_str));
508 } else {
509 parts.push(format!(" {} {}", connector, expr_str));
510 }
511 }
512 }
513
514 if let Some(limit) = stmt.limit {
516 parts.push(format!("LIMIT {}", limit));
517 }
518
519 parts.join("\n")
520}
521
522fn format_expr(expr: &crate::sql::parser::ast::SqlExpression) -> String {
525 crate::sql::parser::ast_formatter::format_expression(expr)
526}
527
528pub fn find_query_context(ast: &SelectStatement, line: usize, _column: usize) -> QueryContext {
530 for (idx, cte) in ast.ctes.iter().enumerate() {
532 let cte_start = 1 + (idx * 5);
535 let cte_end = cte_start + 4;
536
537 if line >= cte_start && line <= cte_end {
538 return QueryContext {
539 context_type: "CTE".to_string(),
540 cte_name: Some(cte.name.clone()),
541 cte_index: Some(idx),
542 query_bounds: QueryBounds {
543 start_line: cte_start,
544 end_line: cte_end,
545 start_offset: 0,
546 end_offset: 0,
547 },
548 parent_query_bounds: Some(QueryBounds {
549 start_line: 1,
550 end_line: 100, start_offset: 0,
552 end_offset: 0,
553 }),
554 can_execute_independently: !matches!(cte.cte_type, CTEType::Web(_)),
555 };
556 }
557 }
558
559 QueryContext {
561 context_type: "main_query".to_string(),
562 cte_name: None,
563 cte_index: None,
564 query_bounds: QueryBounds {
565 start_line: 1,
566 end_line: 100, start_offset: 0,
568 end_offset: 0,
569 },
570 parent_query_bounds: None,
571 can_execute_independently: true,
572 }
573}
574
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sql::recursive_parser::Parser;

    /// A plain `SELECT *` reports the star and the source table.
    #[test]
    fn test_analyze_simple_query() {
        let query = "SELECT * FROM trades WHERE price > 100";
        let mut parser = Parser::new(query);
        let statement = parser.parse().unwrap();

        let report = analyze_query(&statement, query);

        assert!(report.valid);
        assert_eq!(report.query_type, "SELECT");
        assert!(report.has_star);
        assert_eq!(report.star_locations.len(), 1);
        assert_eq!(report.tables, vec!["trades"]);
    }

    /// A standard CTE is listed with its name, kind and star flag.
    #[test]
    fn test_analyze_cte_query() {
        let query = "WITH trades AS (SELECT * FROM raw_trades) SELECT symbol FROM trades";
        let mut parser = Parser::new(query);
        let statement = parser.parse().unwrap();

        let report = analyze_query(&statement, query);

        assert!(report.valid);
        assert_eq!(report.ctes.len(), 1);

        let first_cte = &report.ctes[0];
        assert_eq!(first_cte.name, "trades");
        assert_eq!(first_cte.cte_type, "Standard");
        assert!(first_cte.has_star);
    }

    /// Extracting a CTE keeps its select, source table and filter.
    #[test]
    fn test_extract_cte() {
        let query =
            "WITH trades AS (SELECT * FROM raw_trades WHERE price > 100) SELECT * FROM trades";
        let mut parser = Parser::new(query);
        let statement = parser.parse().unwrap();

        let standalone = extract_cte(&statement, "trades").unwrap();

        for fragment in &["SELECT", "raw_trades", "price > 100"] {
            assert!(standalone.contains(*fragment));
        }
    }
}