1use std::collections::{HashMap, HashSet};
2use std::sync::LazyLock;
3
4use regex::Regex;
5
6use crate::annotations::extract_annotations;
7use crate::error::Result;
8use crate::ir::{ColumnDef, EnumDef, QueryDef, SqlType, SqlTypeCategory, TableDef};
9use crate::parser::joins::{has_outer_join, resolve_multi_table_columns};
10use crate::parser::{
11 DatabaseParser, build_params, ensure_supported_select_expr, make_unknown_column,
12 split_column_defs, split_query_blocks,
13};
14
/// Matches a `CREATE TYPE <name> AS ENUM ( ... )` statement, case-insensitively.
///
/// Capture 1 is the type name. Capture 2 is the raw (possibly empty) comma-separated list of
/// single-quoted values found between the parentheses.
static ENUM_DEF_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r"(?i)CREATE\s+TYPE\s+(\w+)\s+AS\s+ENUM\s*\(\s*((?:'[^']*'(?:\s*,\s*'[^']*')*)?)\s*\)",
    )
    .unwrap()
});

/// Matches one single-quoted literal; capture 1 is the text between the quotes.
static ENUM_VAL_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"'([^']*)'").unwrap());
25
26static CONSTRAINT_RE: LazyLock<Regex> = LazyLock::new(|| {
27 Regex::new(r"(?i)^(PRIMARY\s+KEY|CONSTRAINT|UNIQUE|CHECK|FOREIGN\s+KEY)").unwrap()
28});
29
/// Matches a leading identifier followed by whitespace; capture 1 is the identifier.
static COL_NAME_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^(\w+)\s+").unwrap());

/// Matches a leading single-word type name, optionally followed by `[]`; capture 1 is the type.
static COL_TYPE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^(\w+(?:\[\])?)").unwrap());

/// Finds the whole-word phrase `NOT NULL` anywhere in the input, case-insensitively.
static NOT_NULL_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?i)\bNOT\s+NULL\b").unwrap());

/// Finds the whole word `DEFAULT` anywhere in the input, case-insensitively.
static DEFAULT_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?i)\bDEFAULT\b").unwrap());

/// Finds the whole-word phrase `PRIMARY KEY` anywhere in the input, case-insensitively.
static PK_INLINE_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(?i)\bPRIMARY\s+KEY\b").unwrap());

/// Finds the whole word `UNIQUE` anywhere in the input, case-insensitively.
static UNIQUE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?i)\bUNIQUE\b").unwrap());

/// Matches `CREATE TABLE [IF NOT EXISTS] <name> ( <body> );`, case-insensitively.
///
/// Capture 1 is the table name and capture 2 the body. The body is matched lazily, so it ends at
/// the first `)` that is followed (after optional whitespace) by `;`.
static TABLE_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"(?is)CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)\s*\(([\s\S]*?)\)\s*;")
        .unwrap()
});

/// Matches an entry that starts with `PRIMARY KEY ( ... )`; capture 1 is the comma-separated
/// column list inside the parentheses.
static TABLE_PK_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(?i)^PRIMARY\s+KEY\s*\(\s*([\w\s,]+)\s*\)").unwrap());
50
/// Matches a positional placeholder `$N`; capture 1 is the decimal index.
static PARAM_INDEX_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\$(\d+)").unwrap());

/// Matches `INSERT INTO <table> ( <columns> ) VALUES ( <placeholders> )`, case-insensitively.
///
/// Capture 1 is the column list. Capture 2 is the value list, which may only consist of `$`,
/// digits, whitespace and commas, i.e. bare placeholders.
static INSERT_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"(?i)INSERT\s+INTO\s+\w+\s*\(\s*([\w\s,]+)\s*\)\s*VALUES\s*\(\s*([\$\d\s,]+)\s*\)")
        .unwrap()
});

/// Matches `<operand> <operator> $N`, case-insensitively.
///
/// The operand is either `word(word)` (captures 1 and 2) or a bare word (capture 3). The operator
/// is one of `=`, `!=`, `<>`, `<`, `<=`, `>`, `>=`, or `LIKE` / `ILIKE` / `IN` / `IS` with an
/// optional leading `NOT`. Capture 4 is the placeholder index.
static WHERE_PARAM_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r"(?i)(?:(\w+)\s*\(\s*(\w+)\s*\)|(\w+))\s*(?:=|!=|<>|<=?|>=?|(?:NOT\s+)?(?:I?LIKE|IN|IS))\s*\$(\d+)",
    )
    .unwrap()
});
64
65static FROM_TABLE_RE: LazyLock<Regex> =
66 LazyLock::new(|| Regex::new(r"(?i)(?:FROM|INTO|UPDATE)\s+(\w+)").unwrap());
67
/// Matches a trailing `RETURNING <list>` clause, case-insensitively.
///
/// Capture 1 is the list, matched lazily up to the end of the input; an optional final `;` and
/// trailing whitespace are excluded from the capture.
static RETURNING_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(?i)\bRETURNING\s+([\s\S]+?)(?:;?\s*)$").unwrap());

/// Matches input whose first word (after optional leading whitespace) is `SELECT`.
static SELECT_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?i)^\s*SELECT\b").unwrap());

/// Captures (group 1) the text between `SELECT` and the nearest following `FROM` keyword.
static SELECT_COLS_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(?i)SELECT\s+([\s\S]+?)\s+FROM\b").unwrap());

/// Matches a whole expression of the form `<word> AS <word>`; capture 1 is the source word and
/// capture 2 the alias.
static ALIAS_RE: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"(?i)^(\w+)\s+as\s+(\w+)$").unwrap());
78
79fn type_category(normalized: &str) -> Option<SqlTypeCategory> {
82 match normalized {
83 "text" | "varchar" | "char" | "character varying" | "character" | "name" => {
84 Some(SqlTypeCategory::String)
85 }
86 "integer" | "int" | "int2" | "int4" | "int8" | "smallint" | "bigint" | "serial"
87 | "bigserial" | "real" | "double precision" | "numeric" | "decimal" | "float"
88 | "float4" | "float8" => Some(SqlTypeCategory::Number),
89 "boolean" | "bool" => Some(SqlTypeCategory::Boolean),
90 "timestamp"
91 | "timestamptz"
92 | "date"
93 | "time"
94 | "timetz"
95 | "timestamp without time zone"
96 | "timestamp with time zone" => Some(SqlTypeCategory::Date),
97 "json" | "jsonb" => Some(SqlTypeCategory::Json),
98 "uuid" => Some(SqlTypeCategory::Uuid),
99 "bytea" => Some(SqlTypeCategory::Binary),
100 _ => None,
101 }
102}
103
/// Returns `true` when `normalized` names one of PostgreSQL's auto-incrementing pseudo-types.
///
/// `normalized` must already be lowercased; matching is exact. Besides `serial` and `bigserial`
/// this accepts `smallserial` and the sized aliases `serial2`, `serial4` and `serial8`.
fn is_serial(normalized: &str) -> bool {
    matches!(
        normalized,
        "smallserial" | "serial2" | "serial" | "serial4" | "bigserial" | "serial8"
    )
}
107
108fn resolve_sql_type(raw: &str, enum_names: &HashSet<String>) -> SqlType {
109 let trimmed = raw.trim();
110
111 if let Some(base_raw) = trimmed.strip_suffix("[]") {
113 let element = resolve_sql_type(base_raw, enum_names);
114 return SqlType {
115 raw: trimmed.to_string(),
116 normalized: trimmed.to_lowercase(),
117 category: element.category.clone(),
118 element_type: Some(Box::new(element)),
119 enum_name: None,
120 enum_values: None,
121 json_shape: None,
122 };
123 }
124
125 let normalized = trimmed.to_lowercase();
126
127 if let Some(cat) = type_category(&normalized) {
128 return SqlType {
129 raw: trimmed.to_string(),
130 normalized,
131 category: cat,
132 element_type: None,
133 enum_name: None,
134 enum_values: None,
135 json_shape: None,
136 };
137 }
138
139 if enum_names.contains(&normalized) {
141 return SqlType {
142 raw: trimmed.to_string(),
143 normalized: normalized.clone(),
144 category: SqlTypeCategory::Enum,
145 element_type: None,
146 enum_name: Some(normalized),
147 enum_values: None,
148 json_shape: None,
149 };
150 }
151
152 SqlType {
153 raw: trimmed.to_string(),
154 normalized,
155 category: SqlTypeCategory::Unknown,
156 element_type: None,
157 enum_name: None,
158 enum_values: None,
159 json_shape: None,
160 }
161}
162
163fn parse_enum_defs(sql: &str) -> Vec<EnumDef> {
166 let mut enums = Vec::new();
167 for cap in ENUM_DEF_RE.captures_iter(sql) {
168 let name = cap[1].to_lowercase();
169 let values_raw = &cap[2];
170 let values: Vec<String> = ENUM_VAL_RE
171 .captures_iter(values_raw)
172 .map(|v| v[1].to_string())
173 .collect();
174 enums.push(EnumDef { name, values });
175 }
176 enums
177}
178
/// Lowercase type names that span several words.
///
/// `parse_column_line` checks these as prefixes before falling back to the single-word type
/// regex, which would otherwise stop at the first word.
const MULTI_WORD_TYPES: &[&str] = &[
    "character varying",
    "double precision",
    "timestamp without time zone",
    "timestamp with time zone",
];
187
/// Result of parsing one column entry of a `CREATE TABLE` body.
struct ParsedColumn {
    /// The parsed column definition.
    col: ColumnDef,
    /// Whether the entry carried an inline `PRIMARY KEY`.
    is_pk: bool,
    /// Whether the entry carried an inline `UNIQUE`.
    is_unique: bool,
}
193
194fn parse_column_line(line: &str, enum_names: &HashSet<String>) -> Option<ParsedColumn> {
195 let line = line.trim();
196 if line.is_empty() {
197 return None;
198 }
199
200 if CONSTRAINT_RE.is_match(line) {
202 return None;
203 }
204
205 let name_cap = COL_NAME_RE.captures(line)?;
207 let col_name = name_cap[1].to_lowercase();
208 let after_name = &line[name_cap[0].len()..];
209
210 let mut raw_type: Option<String> = None;
212 for mwt in MULTI_WORD_TYPES {
213 if after_name.to_lowercase().starts_with(mwt) {
214 raw_type = Some(mwt.to_string());
215 break;
216 }
217 }
218 if raw_type.is_none()
219 && let Some(cap) = COL_TYPE_RE.captures(after_name)
220 {
221 raw_type = Some(cap[1].to_string());
222 }
223 let raw_type = raw_type.unwrap_or_else(|| "unknown".to_string());
224
225 let rest = &after_name[raw_type.len()..];
226
227 let is_not_null = NOT_NULL_RE.is_match(rest);
228 let has_default_kw = DEFAULT_RE.is_match(rest);
229 let is_serial_type = is_serial(&raw_type.to_lowercase());
230 let is_pk = PK_INLINE_RE.is_match(rest);
231 let is_unique = UNIQUE_RE.is_match(rest);
232
233 let sql_type = resolve_sql_type(&raw_type, enum_names);
234
235 Some(ParsedColumn {
236 col: ColumnDef {
237 name: col_name,
238 alias: None,
239 source_table: None,
240 sql_type,
241 nullable: !is_not_null,
242 has_default: has_default_kw || is_serial_type,
243 },
244 is_pk,
245 is_unique,
246 })
247}
248
249fn parse_schema_tables(sql: &str, enum_names: &HashSet<String>) -> Vec<TableDef> {
250 let mut tables = Vec::new();
251
252 for cap in TABLE_RE.captures_iter(sql) {
253 let table_name = cap[1].to_lowercase();
254 let body = &cap[2];
255
256 let mut columns = Vec::new();
257 let mut primary_key: Vec<String> = Vec::new();
258 let mut unique_constraints: Vec<Vec<String>> = Vec::new();
259
260 let raw_lines: Vec<&str> = body.lines().collect();
262 let mut pending_comment = String::new();
263 let mut non_comment_buf = String::new();
264 let mut comment_map: HashMap<usize, String> = HashMap::new();
265
266 for raw_line in &raw_lines {
267 let trimmed = raw_line.trim();
268 if trimmed.starts_with("--") {
269 if !pending_comment.is_empty() {
270 pending_comment.push('\n');
271 }
272 pending_comment.push_str(trimmed);
273 } else {
274 let before = split_column_defs(&non_comment_buf)
275 .iter()
276 .filter(|d| !d.is_empty())
277 .count();
278 if !non_comment_buf.is_empty() {
279 non_comment_buf.push('\n');
280 }
281 non_comment_buf.push_str(raw_line);
282 let after = split_column_defs(&non_comment_buf)
283 .iter()
284 .filter(|d| !d.is_empty())
285 .count();
286
287 if after > before && !pending_comment.is_empty() {
288 comment_map.insert(before, pending_comment.clone());
289 pending_comment.clear();
290 } else if after == before {
291 } else {
293 pending_comment.clear();
294 }
295 }
296 }
297
298 let lines = split_column_defs(&non_comment_buf);
299
300 for (i, line) in lines.iter().enumerate() {
301 let trimmed = line.trim();
302
303 if let Some(pk_cap) = TABLE_PK_RE.captures(trimmed) {
305 for col in pk_cap[1].split(',') {
306 primary_key.push(col.trim().to_lowercase());
307 }
308 continue;
309 }
310
311 let Some(mut parsed) = parse_column_line(trimmed, enum_names) else {
312 continue;
313 };
314
315 if let Some(comment) = comment_map.get(&i) {
317 let (_, ann) = extract_annotations(comment);
318 if let Some(values) = ann.enums.get(&parsed.col.name) {
319 parsed.col.sql_type.category = SqlTypeCategory::Enum;
320 parsed.col.sql_type.enum_values = Some(values.clone());
321 }
322 if let Some(shape) = ann.json_shapes.get(&parsed.col.name) {
323 parsed.col.sql_type.json_shape = Some(shape.clone());
324 }
325 }
326
327 if parsed.is_pk {
328 primary_key.push(parsed.col.name.clone());
329 }
330 if parsed.is_unique {
331 unique_constraints.push(vec![parsed.col.name.clone()]);
332 }
333 columns.push(parsed.col);
334 }
335
336 for col in &mut columns {
338 if primary_key.contains(&col.name) {
339 col.nullable = false;
340 if is_serial(&col.sql_type.normalized) {
341 col.has_default = true;
342 }
343 }
344 }
345
346 tables.push(TableDef {
347 name: table_name,
348 columns,
349 primary_key,
350 unique_constraints,
351 });
352 }
353
354 tables
355}
356
357fn extract_param_indices(sql: &str) -> Vec<u32> {
360 let mut indices: HashSet<u32> = HashSet::new();
361 for cap in PARAM_INDEX_RE.captures_iter(sql) {
362 if let Ok(idx) = cap[1].parse::<u32>() {
363 indices.insert(idx);
364 }
365 }
366 let mut sorted: Vec<u32> = indices.into_iter().collect();
367 sorted.sort();
368 sorted
369}
370
371fn infer_param_columns(sql: &str) -> HashMap<u32, String> {
372 let mut result = HashMap::new();
373
374 if let Some(cap) = INSERT_RE.captures(sql) {
376 let cols: Vec<String> = cap[1].split(',').map(|s| s.trim().to_lowercase()).collect();
377 let params: Vec<u32> = PARAM_INDEX_RE
378 .captures_iter(&cap[2])
379 .filter_map(|m| m[1].parse().ok())
380 .collect();
381
382 for (i, idx) in params.iter().enumerate() {
383 if i < cols.len() {
384 result.insert(*idx, cols[i].clone());
385 }
386 }
387 return result;
388 }
389
390 let sql_keywords: HashSet<&str> = [
392 "not", "and", "or", "where", "set", "when", "then", "else", "case", "between", "exists",
393 "any", "all", "some", "having",
394 ]
395 .into_iter()
396 .collect();
397
398 for cap in WHERE_PARAM_RE.captures_iter(sql) {
399 if let Ok(idx) = cap[4].parse::<u32>() {
400 if cap.get(1).is_some() && cap.get(2).is_some() {
401 result.insert(idx, cap[2].to_lowercase());
403 } else if let Some(m) = cap.get(3) {
404 let word = m.as_str().to_lowercase();
405 if !sql_keywords.contains(word.as_str()) {
406 result.insert(idx, word);
407 }
408 }
409 }
410 }
411
412 result
413}
414
415fn find_from_table<'a>(sql: &str, tables: &'a [TableDef]) -> Option<&'a TableDef> {
416 let cap = FROM_TABLE_RE.captures(sql)?;
417 let table_name = cap[1].to_lowercase();
418 tables.iter().find(|t| t.name == table_name)
419}
420
421fn resolve_returning_columns(sql: &str, table: Option<&TableDef>) -> Option<Vec<ColumnDef>> {
422 let cap = RETURNING_RE.captures(sql)?;
423 let cols_part = cap[1].trim();
424
425 if cols_part == "*" {
426 return Some(table.map(|t| t.columns.clone()).unwrap_or_default());
427 }
428
429 let table = table?;
430 Some(
431 cols_part
432 .split(',')
433 .map(|s| {
434 let name = s.trim().to_lowercase();
435 table
436 .columns
437 .iter()
438 .find(|c| c.name == name)
439 .cloned()
440 .unwrap_or_else(|| make_unknown_column(&name))
441 })
442 .collect(),
443 )
444}
445
446fn resolve_return_columns(
447 sql: &str,
448 table: Option<&TableDef>,
449 schema_tables: &[TableDef],
450 source_file: &str,
451) -> Result<Vec<ColumnDef>> {
452 if let Some(returning) = resolve_returning_columns(sql, table) {
454 return Ok(returning);
455 }
456
457 if !SELECT_RE.is_match(sql) {
458 return Ok(Vec::new());
459 }
460
461 let Some(cap) = SELECT_COLS_RE.captures(sql) else {
462 return Ok(Vec::new());
463 };
464 let cols_part = cap[1].trim();
465
466 if has_outer_join(sql) {
472 return resolve_multi_table_columns(cols_part, sql, schema_tables, source_file);
473 }
474
475 if cols_part == "*" {
476 return Ok(table.map(|t| t.columns.clone()).unwrap_or_default());
477 }
478
479 let Some(table) = table else {
480 return Ok(Vec::new());
481 };
482
483 let col_names: Vec<&str> = cols_part.split(',').map(|s| s.trim()).collect();
484
485 col_names
486 .iter()
487 .map(|&col_expr| -> Result<ColumnDef> {
488 ensure_supported_select_expr(col_expr, source_file)?;
489 let expr_lower = col_expr.to_lowercase();
490 if let Some(alias_cap) = ALIAS_RE.captures(&expr_lower) {
491 let actual = &alias_cap[1];
492 let alias = alias_cap[2].to_string();
493 Ok(table
494 .columns
495 .iter()
496 .find(|c| c.name == actual)
497 .map(|c| {
498 let mut col = c.clone();
499 col.alias = Some(alias);
500 col
501 })
502 .unwrap_or_else(|| make_unknown_column(actual)))
503 } else {
504 Ok(table
505 .columns
506 .iter()
507 .find(|c| c.name == expr_lower)
508 .cloned()
509 .unwrap_or_else(|| make_unknown_column(&expr_lower)))
510 }
511 })
512 .collect()
513}
514
/// Schema and query parser for the PostgreSQL dialect.
///
/// The type carries no state, so it is a unit struct; `Default`, `Debug`, `Clone` and `Copy` are
/// derived rather than written by hand.
#[derive(Debug, Clone, Copy, Default)]
pub struct PostgresParser;

impl PostgresParser {
    /// Creates a parser. Equivalent to `PostgresParser::default()`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
530
531impl DatabaseParser for PostgresParser {
532 fn parse_schema(&self, sql: &str) -> Result<(Vec<TableDef>, Vec<EnumDef>)> {
533 let enums = parse_enum_defs(sql);
534 let enum_names: HashSet<String> = enums.iter().map(|e| e.name.clone()).collect();
535 let tables = parse_schema_tables(sql, &enum_names);
536 Ok((tables, enums))
537 }
538
539 fn parse_queries(
540 &self,
541 sql: &str,
542 tables: &[TableDef],
543 enums: &[EnumDef],
544 source_file: &str,
545 ) -> Result<Vec<QueryDef>> {
546 let _ = enums; let blocks = split_query_blocks(sql);
548 let mut queries = Vec::new();
549
550 for block in blocks {
551 let table = find_from_table(&block.sql, tables);
552 let param_indices = extract_param_indices(&block.sql);
553 let inferred_cols = infer_param_columns(&block.sql);
554 let params = build_params(&block.comments, table, param_indices, inferred_cols);
555 let returns = resolve_return_columns(&block.sql, table, tables, source_file)?;
556
557 let clean_sql = block
558 .sql
559 .trim_end()
560 .trim_end_matches(';')
561 .trim()
562 .to_string();
563
564 queries.push(QueryDef {
565 name: block.name,
566 command: block.command,
567 sql: clean_sql,
568 params,
569 returns,
570 source_file: source_file.to_string(),
571 });
572 }
573
574 Ok(queries)
575 }
576}
577
// Unit tests for the PostgreSQL parser: fixture-based schema/query parsing, type resolution,
// parameter inference, parser lookup by dialect name, and JOIN handling.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{QueryCommand, SqlTypeCategory};
    use crate::parser::DatabaseParser;

    // Shared fixtures compiled into the test binary.
    const SCHEMA_SQL: &str = include_str!("../../../../tests/fixtures/schema.sql");
    const QUERIES_SQL: &str = include_str!("../../../../tests/fixtures/queries/users.sql");

    // The fixture schema declares exactly one enum, `user_status`, with three values.
    #[test]
    fn parses_enum_type() {
        let parser = PostgresParser::new();
        let (_, enums) = parser.parse_schema(SCHEMA_SQL).unwrap();
        assert_eq!(enums.len(), 1);
        assert_eq!(enums[0].name, "user_status");
        assert_eq!(enums[0].values, vec!["active", "inactive", "banned"]);
    }

    // `users` has seven columns, `id` as primary key, a nullable `bio` and an array `tags`.
    #[test]
    fn parses_users_table() {
        let parser = PostgresParser::new();
        let (tables, _) = parser.parse_schema(SCHEMA_SQL).unwrap();
        let users = tables.iter().find(|t| t.name == "users").unwrap();
        assert_eq!(users.columns.len(), 7);
        assert_eq!(users.primary_key, vec!["id"]);

        let id_col = &users.columns[0];
        assert_eq!(id_col.name, "id");
        assert_eq!(id_col.sql_type.category, SqlTypeCategory::Number);
        // The primary key has a default and is never nullable.
        assert!(id_col.has_default);
        assert!(!id_col.nullable);

        let bio_col = users.columns.iter().find(|c| c.name == "bio").unwrap();
        assert!(bio_col.nullable);

        let tags_col = users.columns.iter().find(|c| c.name == "tags").unwrap();
        assert!(tags_col.sql_type.element_type.is_some());
    }

    // `posts` has six columns.
    #[test]
    fn parses_posts_table() {
        let parser = PostgresParser::new();
        let (tables, _) = parser.parse_schema(SCHEMA_SQL).unwrap();
        let posts = tables.iter().find(|t| t.name == "posts").unwrap();
        assert_eq!(posts.columns.len(), 6);
    }

    // `GetUser` is a `:one` query with a single `id` parameter returning all seven columns.
    #[test]
    fn parses_get_user_query() {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(SCHEMA_SQL).unwrap();
        let queries = parser
            .parse_queries(QUERIES_SQL, &tables, &enums, "queries/users.sql")
            .unwrap();
        let get_user = queries.iter().find(|q| q.name == "GetUser").unwrap();
        assert_eq!(get_user.command, QueryCommand::One);
        assert_eq!(get_user.params.len(), 1);
        assert_eq!(get_user.params[0].name, "id");
        assert_eq!(get_user.returns.len(), 7);
    }

    // `ListUsers` is a `:many` query that returns three columns.
    #[test]
    fn parses_list_users_partial_select() {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(SCHEMA_SQL).unwrap();
        let queries = parser
            .parse_queries(QUERIES_SQL, &tables, &enums, "queries/users.sql")
            .unwrap();
        let list_users = queries.iter().find(|q| q.name == "ListUsers").unwrap();
        assert_eq!(list_users.command, QueryCommand::Many);
        assert_eq!(list_users.returns.len(), 3);
    }

    // `CreateUser` is an `:exec` query with three parameters and no returned columns.
    #[test]
    fn parses_create_user_exec() {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(SCHEMA_SQL).unwrap();
        let queries = parser
            .parse_queries(QUERIES_SQL, &tables, &enums, "queries/users.sql")
            .unwrap();
        let create_user = queries.iter().find(|q| q.name == "CreateUser").unwrap();
        assert_eq!(create_user.command, QueryCommand::Exec);
        assert_eq!(create_user.params.len(), 3);
        assert!(create_user.returns.is_empty());
    }

    // `DeleteUser` is parsed with the `ExecResult` command.
    #[test]
    fn parses_delete_user_execresult() {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(SCHEMA_SQL).unwrap();
        let queries = parser
            .parse_queries(QUERIES_SQL, &tables, &enums, "queries/users.sql")
            .unwrap();
        let delete_user = queries.iter().find(|q| q.name == "DeleteUser").unwrap();
        assert_eq!(delete_user.command, QueryCommand::ExecResult);
    }

    // The parameters of `ListUsersByDateRange` are named `start_date` and `end_date`.
    #[test]
    fn parses_param_overrides() {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(SCHEMA_SQL).unwrap();
        let queries = parser
            .parse_queries(QUERIES_SQL, &tables, &enums, "queries/users.sql")
            .unwrap();
        let date_range = queries
            .iter()
            .find(|q| q.name == "ListUsersByDateRange")
            .unwrap();
        assert_eq!(date_range.params[0].name, "start_date");
        assert_eq!(date_range.params[1].name, "end_date");
    }

    // Built-in type names resolve to their category regardless of case.
    #[test]
    fn resolve_type_maps_common_types() {
        let enums = HashSet::new();

        let text = resolve_sql_type("TEXT", &enums);
        assert_eq!(text.category, SqlTypeCategory::String);

        let int = resolve_sql_type("INTEGER", &enums);
        assert_eq!(int.category, SqlTypeCategory::Number);

        let bool_t = resolve_sql_type("BOOLEAN", &enums);
        assert_eq!(bool_t.category, SqlTypeCategory::Boolean);

        let ts = resolve_sql_type("TIMESTAMP", &enums);
        assert_eq!(ts.category, SqlTypeCategory::Date);

        let json = resolve_sql_type("JSONB", &enums);
        assert_eq!(json.category, SqlTypeCategory::Json);

        let uuid = resolve_sql_type("UUID", &enums);
        assert_eq!(uuid.category, SqlTypeCategory::Uuid);

        let bytea = resolve_sql_type("BYTEA", &enums);
        assert_eq!(bytea.category, SqlTypeCategory::Binary);
    }

    // An array type inherits its element's category and records the element type.
    #[test]
    fn resolve_type_array() {
        let enums = HashSet::new();
        let arr = resolve_sql_type("TEXT[]", &enums);
        assert_eq!(arr.category, SqlTypeCategory::String);
        assert!(arr.element_type.is_some());
        assert_eq!(arr.element_type.unwrap().category, SqlTypeCategory::String);
    }

    // A name listed in the enum set resolves to the Enum category with `enum_name` set.
    #[test]
    fn resolve_type_enum() {
        let mut enums = HashSet::new();
        enums.insert("user_status".to_string());
        let t = resolve_sql_type("user_status", &enums);
        assert_eq!(t.category, SqlTypeCategory::Enum);
        assert_eq!(t.enum_name, Some("user_status".to_string()));
    }

    // INSERT placeholders map to the column list by position.
    #[test]
    fn infer_insert_params() {
        let sql = "INSERT INTO users (name, email, bio) VALUES ($1, $2, $3)";
        let cols = infer_param_columns(sql);
        assert_eq!(cols.get(&1), Some(&"name".to_string()));
        assert_eq!(cols.get(&2), Some(&"email".to_string()));
        assert_eq!(cols.get(&3), Some(&"bio".to_string()));
    }

    // A `col = $N` comparison maps the placeholder to the compared column.
    #[test]
    fn infer_where_params() {
        let sql = "SELECT * FROM users WHERE id = $1";
        let cols = infer_param_columns(sql);
        assert_eq!(cols.get(&1), Some(&"id".to_string()));
    }

    // Two `-- name:` headers produce two named blocks, in order.
    #[test]
    fn split_query_blocks_basic() {
        let blocks = split_query_blocks(
            "-- name: GetUser :one\nSELECT * FROM users WHERE id = $1;\n\n-- name: ListUsers :many\nSELECT id, name FROM users;",
        );
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].name, "GetUser");
        assert_eq!(blocks[1].name, "ListUsers");
    }

    // `resolve_parser` accepts the "postgres", "mysql" and "sqlite" names and rejects "oracle".
    #[test]
    fn resolve_parser_postgres() {
        let parser = crate::parser::resolve_parser("postgres");
        assert!(parser.is_ok());
    }

    #[test]
    fn resolve_parser_mysql() {
        let parser = crate::parser::resolve_parser("mysql");
        assert!(parser.is_ok());
    }

    #[test]
    fn resolve_parser_sqlite() {
        let parser = crate::parser::resolve_parser("sqlite");
        assert!(parser.is_ok());
    }

    #[test]
    fn resolve_parser_unknown() {
        let parser = crate::parser::resolve_parser("oracle");
        assert!(parser.is_err());
    }

    // Minimal two-table schema (`users`, `orgs`) shared by the JOIN tests below.
    fn join_schema() -> &'static str {
        r#"
            CREATE TABLE users (
                id INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                org_id INTEGER NOT NULL
            );
            CREATE TABLE orgs (
                id INTEGER PRIMARY KEY,
                slug TEXT NOT NULL
            );
        "#
    }

    // Qualified columns in an INNER JOIN resolve to their owning table.
    #[test]
    fn inner_join_resolves_qualified_columns() {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(join_schema()).unwrap();
        let sql = "-- name: GetUserWithOrg :one\nSELECT users.name, orgs.slug FROM users INNER JOIN orgs ON users.org_id = orgs.id WHERE users.id = $1;";
        let queries = parser.parse_queries(sql, &tables, &enums, "q.sql").unwrap();
        assert_eq!(queries.len(), 1);
        let q = &queries[0];
        assert_eq!(q.returns.len(), 2);
        assert_eq!(q.returns[0].name, "name");
        assert_eq!(q.returns[0].source_table.as_deref(), Some("users"));
        assert_eq!(q.returns[1].name, "slug");
        assert_eq!(q.returns[1].source_table.as_deref(), Some("orgs"));
    }

    // Table aliases resolve to the real table and `AS` aliases are kept on the column.
    #[test]
    fn inner_join_accepts_aliases_and_as() {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(join_schema()).unwrap();
        let sql = "-- name: Listing :many\nSELECT u.id AS user_id, o.slug AS org_slug FROM users u INNER JOIN orgs o ON u.org_id = o.id;";
        let queries = parser.parse_queries(sql, &tables, &enums, "q.sql").unwrap();
        let q = &queries[0];
        assert_eq!(q.returns[0].name, "id");
        assert_eq!(q.returns[0].alias.as_deref(), Some("user_id"));
        assert_eq!(q.returns[0].source_table.as_deref(), Some("users"));
        assert_eq!(q.returns[1].alias.as_deref(), Some("org_slug"));
        assert_eq!(q.returns[1].source_table.as_deref(), Some("orgs"));
    }

    // `SELECT *` across a JOIN is rejected with a dedicated error message.
    #[test]
    fn inner_join_rejects_select_star() {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(join_schema()).unwrap();
        let sql = "-- name: Everything :many\nSELECT * FROM users INNER JOIN orgs ON users.org_id = orgs.id;";
        let err = parser
            .parse_queries(sql, &tables, &enums, "q.sql")
            .unwrap_err();
        assert!(
            err.to_string()
                .contains("SELECT * across multi-table JOINs")
        );
    }

    // LEFT JOIN is rejected; the error states that only INNER JOIN is supported.
    #[test]
    fn left_join_rejected_with_v12_pointer() {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(join_schema()).unwrap();
        let sql = "-- name: WithLeft :many\nSELECT users.id FROM users LEFT JOIN orgs ON users.org_id = orgs.id;";
        let err = parser
            .parse_queries(sql, &tables, &enums, "q.sql")
            .unwrap_err();
        assert!(err.to_string().contains("v1.1 supports INNER JOIN only"));
    }

    // Without a JOIN, a qualified select expression is still an error.
    #[test]
    fn single_table_path_still_rejects_qualified_selects() {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(join_schema()).unwrap();
        let sql = "-- name: Bad :one\nSELECT users.id FROM users WHERE users.id = $1;";
        let err = parser
            .parse_queries(sql, &tables, &enums, "q.sql")
            .unwrap_err();
        assert!(
            err.to_string()
                .contains("qualified select expressions are not supported")
        );
    }

    // A JOIN that appears only inside a subquery leaves the outer SELECT on the single-table
    // path: one unqualified `id` column with no source table.
    #[test]
    fn join_in_subquery_does_not_route_outer_to_multi_table() {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(join_schema()).unwrap();
        let sql = "-- name: SubquerySafe :many\nSELECT id FROM users WHERE id IN (SELECT users.id FROM users INNER JOIN orgs ON users.org_id = orgs.id);";
        let queries = parser.parse_queries(sql, &tables, &enums, "q.sql").unwrap();
        assert_eq!(queries[0].returns.len(), 1);
        assert_eq!(queries[0].returns[0].name, "id");
        assert_eq!(queries[0].returns[0].source_table, None);
    }
}