Skip to main content

sqlcx_core/parser/
postgres.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::LazyLock;
3
4use regex::Regex;
5
6use crate::annotations::extract_annotations;
7use crate::error::Result;
8use crate::ir::{ColumnDef, EnumDef, QueryDef, SqlType, SqlTypeCategory, TableDef};
9use crate::parser::joins::{has_outer_join, resolve_multi_table_columns};
10use crate::parser::{
11    DatabaseParser, build_params, ensure_supported_select_expr, make_unknown_column,
12    split_column_defs, split_query_blocks,
13};
14
15// ── Static regex patterns ────────────────────────────────────────────────────
16
17static ENUM_DEF_RE: LazyLock<Regex> = LazyLock::new(|| {
18    Regex::new(
19        r"(?i)CREATE\s+TYPE\s+(\w+)\s+AS\s+ENUM\s*\(\s*((?:'[^']*'(?:\s*,\s*'[^']*')*)?)\s*\)",
20    )
21    .unwrap()
22});
23
24static ENUM_VAL_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"'([^']*)'").unwrap());
25
26static CONSTRAINT_RE: LazyLock<Regex> = LazyLock::new(|| {
27    Regex::new(r"(?i)^(PRIMARY\s+KEY|CONSTRAINT|UNIQUE|CHECK|FOREIGN\s+KEY)").unwrap()
28});
29
30static COL_NAME_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^(\w+)\s+").unwrap());
31
32static COL_TYPE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"^(\w+(?:\[\])?)").unwrap());
33
34static NOT_NULL_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?i)\bNOT\s+NULL\b").unwrap());
35
36static DEFAULT_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?i)\bDEFAULT\b").unwrap());
37
38static PK_INLINE_RE: LazyLock<Regex> =
39    LazyLock::new(|| Regex::new(r"(?i)\bPRIMARY\s+KEY\b").unwrap());
40
41static UNIQUE_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?i)\bUNIQUE\b").unwrap());
42
43static TABLE_RE: LazyLock<Regex> = LazyLock::new(|| {
44    Regex::new(r"(?is)CREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?(\w+)\s*\(([\s\S]*?)\)\s*;")
45        .unwrap()
46});
47
48static TABLE_PK_RE: LazyLock<Regex> =
49    LazyLock::new(|| Regex::new(r"(?i)^PRIMARY\s+KEY\s*\(\s*([\w\s,]+)\s*\)").unwrap());
50
51static PARAM_INDEX_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\$(\d+)").unwrap());
52
53static INSERT_RE: LazyLock<Regex> = LazyLock::new(|| {
54    Regex::new(r"(?i)INSERT\s+INTO\s+\w+\s*\(\s*([\w\s,]+)\s*\)\s*VALUES\s*\(\s*([\$\d\s,]+)\s*\)")
55        .unwrap()
56});
57
58static WHERE_PARAM_RE: LazyLock<Regex> = LazyLock::new(|| {
59    Regex::new(
60        r"(?i)(?:(\w+)\s*\(\s*(\w+)\s*\)|(\w+))\s*(?:=|!=|<>|<=?|>=?|(?:NOT\s+)?(?:I?LIKE|IN|IS))\s*\$(\d+)",
61    )
62    .unwrap()
63});
64
65static FROM_TABLE_RE: LazyLock<Regex> =
66    LazyLock::new(|| Regex::new(r"(?i)(?:FROM|INTO|UPDATE)\s+(\w+)").unwrap());
67
68static RETURNING_RE: LazyLock<Regex> =
69    LazyLock::new(|| Regex::new(r"(?i)\bRETURNING\s+([\s\S]+?)(?:;?\s*)$").unwrap());
70
71static SELECT_RE: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"(?i)^\s*SELECT\b").unwrap());
72
73static SELECT_COLS_RE: LazyLock<Regex> =
74    LazyLock::new(|| Regex::new(r"(?i)SELECT\s+([\s\S]+?)\s+FROM\b").unwrap());
75
76static ALIAS_RE: LazyLock<Regex> =
77    LazyLock::new(|| Regex::new(r"(?i)^(\w+)\s+as\s+(\w+)$").unwrap());
78
79// ── Type mapping ─────────────────────────────────────────────────────────────
80
81fn type_category(normalized: &str) -> Option<SqlTypeCategory> {
82    match normalized {
83        "text" | "varchar" | "char" | "character varying" | "character" | "name" => {
84            Some(SqlTypeCategory::String)
85        }
86        "integer" | "int" | "int2" | "int4" | "int8" | "smallint" | "bigint" | "serial"
87        | "bigserial" | "real" | "double precision" | "numeric" | "decimal" | "float"
88        | "float4" | "float8" => Some(SqlTypeCategory::Number),
89        "boolean" | "bool" => Some(SqlTypeCategory::Boolean),
90        "timestamp"
91        | "timestamptz"
92        | "date"
93        | "time"
94        | "timetz"
95        | "timestamp without time zone"
96        | "timestamp with time zone" => Some(SqlTypeCategory::Date),
97        "json" | "jsonb" => Some(SqlTypeCategory::Json),
98        "uuid" => Some(SqlTypeCategory::Uuid),
99        "bytea" => Some(SqlTypeCategory::Binary),
100        _ => None,
101    }
102}
103
/// Returns `true` for PostgreSQL's auto-incrementing pseudo-types.
///
/// Callers use this to mark columns as having an implicit default. Every
/// spelling is recognised: `smallserial`/`serial2`, `serial`/`serial4`,
/// `bigserial`/`serial8`. The input is expected to be lower-case already;
/// this function does not fold case.
fn is_serial(normalized: &str) -> bool {
    matches!(
        normalized,
        "smallserial" | "serial" | "bigserial" | "serial2" | "serial4" | "serial8"
    )
}
107
108fn resolve_sql_type(raw: &str, enum_names: &HashSet<String>) -> SqlType {
109    let trimmed = raw.trim();
110
111    // Array detection
112    if let Some(base_raw) = trimmed.strip_suffix("[]") {
113        let element = resolve_sql_type(base_raw, enum_names);
114        return SqlType {
115            raw: trimmed.to_string(),
116            normalized: trimmed.to_lowercase(),
117            category: element.category.clone(),
118            element_type: Some(Box::new(element)),
119            enum_name: None,
120            enum_values: None,
121            json_shape: None,
122        };
123    }
124
125    let normalized = trimmed.to_lowercase();
126
127    if let Some(cat) = type_category(&normalized) {
128        return SqlType {
129            raw: trimmed.to_string(),
130            normalized,
131            category: cat,
132            element_type: None,
133            enum_name: None,
134            enum_values: None,
135            json_shape: None,
136        };
137    }
138
139    // Check for known enum
140    if enum_names.contains(&normalized) {
141        return SqlType {
142            raw: trimmed.to_string(),
143            normalized: normalized.clone(),
144            category: SqlTypeCategory::Enum,
145            element_type: None,
146            enum_name: Some(normalized),
147            enum_values: None,
148            json_shape: None,
149        };
150    }
151
152    SqlType {
153        raw: trimmed.to_string(),
154        normalized,
155        category: SqlTypeCategory::Unknown,
156        element_type: None,
157        enum_name: None,
158        enum_values: None,
159        json_shape: None,
160    }
161}
162
163// ── Enum parsing ─────────────────────────────────────────────────────────────
164
165fn parse_enum_defs(sql: &str) -> Vec<EnumDef> {
166    let mut enums = Vec::new();
167    for cap in ENUM_DEF_RE.captures_iter(sql) {
168        let name = cap[1].to_lowercase();
169        let values_raw = &cap[2];
170        let values: Vec<String> = ENUM_VAL_RE
171            .captures_iter(values_raw)
172            .map(|v| v[1].to_string())
173            .collect();
174        enums.push(EnumDef { name, values });
175    }
176    enums
177}
178
179// ── Schema parsing (regex-based, matching TS) ────────────────────────────────
180
/// Lower-case type names that contain spaces.
///
/// The single-word type regex would stop at the first space, so column parsing
/// tries these as prefixes first.
const MULTI_WORD_TYPES: &[&str] = &[
    // String family
    "character varying",
    // Number family
    "double precision",
    // Date family
    "timestamp without time zone",
    "timestamp with time zone",
];
187
188struct ParsedColumn {
189    col: ColumnDef,
190    is_pk: bool,
191    is_unique: bool,
192}
193
194fn parse_column_line(line: &str, enum_names: &HashSet<String>) -> Option<ParsedColumn> {
195    let line = line.trim();
196    if line.is_empty() {
197        return None;
198    }
199
200    // Skip constraint lines
201    if CONSTRAINT_RE.is_match(line) {
202        return None;
203    }
204
205    // Extract column name (first word)
206    let name_cap = COL_NAME_RE.captures(line)?;
207    let col_name = name_cap[1].to_lowercase();
208    let after_name = &line[name_cap[0].len()..];
209
210    // Determine the type - check multi-word types first
211    let mut raw_type: Option<String> = None;
212    for mwt in MULTI_WORD_TYPES {
213        if after_name.to_lowercase().starts_with(mwt) {
214            raw_type = Some(mwt.to_string());
215            break;
216        }
217    }
218    if raw_type.is_none()
219        && let Some(cap) = COL_TYPE_RE.captures(after_name)
220    {
221        raw_type = Some(cap[1].to_string());
222    }
223    let raw_type = raw_type.unwrap_or_else(|| "unknown".to_string());
224
225    let rest = &after_name[raw_type.len()..];
226
227    let is_not_null = NOT_NULL_RE.is_match(rest);
228    let has_default_kw = DEFAULT_RE.is_match(rest);
229    let is_serial_type = is_serial(&raw_type.to_lowercase());
230    let is_pk = PK_INLINE_RE.is_match(rest);
231    let is_unique = UNIQUE_RE.is_match(rest);
232
233    let sql_type = resolve_sql_type(&raw_type, enum_names);
234
235    Some(ParsedColumn {
236        col: ColumnDef {
237            name: col_name,
238            alias: None,
239            source_table: None,
240            sql_type,
241            nullable: !is_not_null,
242            has_default: has_default_kw || is_serial_type,
243        },
244        is_pk,
245        is_unique,
246    })
247}
248
249fn parse_schema_tables(sql: &str, enum_names: &HashSet<String>) -> Vec<TableDef> {
250    let mut tables = Vec::new();
251
252    for cap in TABLE_RE.captures_iter(sql) {
253        let table_name = cap[1].to_lowercase();
254        let body = &cap[2];
255
256        let mut columns = Vec::new();
257        let mut primary_key: Vec<String> = Vec::new();
258        let mut unique_constraints: Vec<Vec<String>> = Vec::new();
259
260        // Split body into lines, track comments for annotations
261        let raw_lines: Vec<&str> = body.lines().collect();
262        let mut pending_comment = String::new();
263        let mut non_comment_buf = String::new();
264        let mut comment_map: HashMap<usize, String> = HashMap::new();
265
266        for raw_line in &raw_lines {
267            let trimmed = raw_line.trim();
268            if trimmed.starts_with("--") {
269                if !pending_comment.is_empty() {
270                    pending_comment.push('\n');
271                }
272                pending_comment.push_str(trimmed);
273            } else {
274                let before = split_column_defs(&non_comment_buf)
275                    .iter()
276                    .filter(|d| !d.is_empty())
277                    .count();
278                if !non_comment_buf.is_empty() {
279                    non_comment_buf.push('\n');
280                }
281                non_comment_buf.push_str(raw_line);
282                let after = split_column_defs(&non_comment_buf)
283                    .iter()
284                    .filter(|d| !d.is_empty())
285                    .count();
286
287                if after > before && !pending_comment.is_empty() {
288                    comment_map.insert(before, pending_comment.clone());
289                    pending_comment.clear();
290                } else if after == before {
291                    // Still accumulating same def
292                } else {
293                    pending_comment.clear();
294                }
295            }
296        }
297
298        let lines = split_column_defs(&non_comment_buf);
299
300        for (i, line) in lines.iter().enumerate() {
301            let trimmed = line.trim();
302
303            // Table-level PRIMARY KEY constraint
304            if let Some(pk_cap) = TABLE_PK_RE.captures(trimmed) {
305                for col in pk_cap[1].split(',') {
306                    primary_key.push(col.trim().to_lowercase());
307                }
308                continue;
309            }
310
311            let Some(mut parsed) = parse_column_line(trimmed, enum_names) else {
312                continue;
313            };
314
315            // Apply annotations from comment above this column
316            if let Some(comment) = comment_map.get(&i) {
317                let (_, ann) = extract_annotations(comment);
318                if let Some(values) = ann.enums.get(&parsed.col.name) {
319                    parsed.col.sql_type.category = SqlTypeCategory::Enum;
320                    parsed.col.sql_type.enum_values = Some(values.clone());
321                }
322                if let Some(shape) = ann.json_shapes.get(&parsed.col.name) {
323                    parsed.col.sql_type.json_shape = Some(shape.clone());
324                }
325            }
326
327            if parsed.is_pk {
328                primary_key.push(parsed.col.name.clone());
329            }
330            if parsed.is_unique {
331                unique_constraints.push(vec![parsed.col.name.clone()]);
332            }
333            columns.push(parsed.col);
334        }
335
336        // PK columns are implicitly NOT NULL
337        for col in &mut columns {
338            if primary_key.contains(&col.name) {
339                col.nullable = false;
340                if is_serial(&col.sql_type.normalized) {
341                    col.has_default = true;
342                }
343            }
344        }
345
346        tables.push(TableDef {
347            name: table_name,
348            columns,
349            primary_key,
350            unique_constraints,
351        });
352    }
353
354    tables
355}
356
357// ── Query parsing ────────────────────────────────────────────────────────────
358
359fn extract_param_indices(sql: &str) -> Vec<u32> {
360    let mut indices: HashSet<u32> = HashSet::new();
361    for cap in PARAM_INDEX_RE.captures_iter(sql) {
362        if let Ok(idx) = cap[1].parse::<u32>() {
363            indices.insert(idx);
364        }
365    }
366    let mut sorted: Vec<u32> = indices.into_iter().collect();
367    sorted.sort();
368    sorted
369}
370
371fn infer_param_columns(sql: &str) -> HashMap<u32, String> {
372    let mut result = HashMap::new();
373
374    // INSERT pattern
375    if let Some(cap) = INSERT_RE.captures(sql) {
376        let cols: Vec<String> = cap[1].split(',').map(|s| s.trim().to_lowercase()).collect();
377        let params: Vec<u32> = PARAM_INDEX_RE
378            .captures_iter(&cap[2])
379            .filter_map(|m| m[1].parse().ok())
380            .collect();
381
382        for (i, idx) in params.iter().enumerate() {
383            if i < cols.len() {
384                result.insert(*idx, cols[i].clone());
385            }
386        }
387        return result;
388    }
389
390    // WHERE/SET pattern
391    let sql_keywords: HashSet<&str> = [
392        "not", "and", "or", "where", "set", "when", "then", "else", "case", "between", "exists",
393        "any", "all", "some", "having",
394    ]
395    .into_iter()
396    .collect();
397
398    for cap in WHERE_PARAM_RE.captures_iter(sql) {
399        if let Ok(idx) = cap[4].parse::<u32>() {
400            if cap.get(1).is_some() && cap.get(2).is_some() {
401                // FUNC(col) pattern
402                result.insert(idx, cap[2].to_lowercase());
403            } else if let Some(m) = cap.get(3) {
404                let word = m.as_str().to_lowercase();
405                if !sql_keywords.contains(word.as_str()) {
406                    result.insert(idx, word);
407                }
408            }
409        }
410    }
411
412    result
413}
414
415fn find_from_table<'a>(sql: &str, tables: &'a [TableDef]) -> Option<&'a TableDef> {
416    let cap = FROM_TABLE_RE.captures(sql)?;
417    let table_name = cap[1].to_lowercase();
418    tables.iter().find(|t| t.name == table_name)
419}
420
421fn resolve_returning_columns(sql: &str, table: Option<&TableDef>) -> Option<Vec<ColumnDef>> {
422    let cap = RETURNING_RE.captures(sql)?;
423    let cols_part = cap[1].trim();
424
425    if cols_part == "*" {
426        return Some(table.map(|t| t.columns.clone()).unwrap_or_default());
427    }
428
429    let table = table?;
430    Some(
431        cols_part
432            .split(',')
433            .map(|s| {
434                let name = s.trim().to_lowercase();
435                table
436                    .columns
437                    .iter()
438                    .find(|c| c.name == name)
439                    .cloned()
440                    .unwrap_or_else(|| make_unknown_column(&name))
441            })
442            .collect(),
443    )
444}
445
446fn resolve_return_columns(
447    sql: &str,
448    table: Option<&TableDef>,
449    schema_tables: &[TableDef],
450    source_file: &str,
451) -> Result<Vec<ColumnDef>> {
452    // Check RETURNING clause first
453    if let Some(returning) = resolve_returning_columns(sql, table) {
454        return Ok(returning);
455    }
456
457    if !SELECT_RE.is_match(sql) {
458        return Ok(Vec::new());
459    }
460
461    let Some(cap) = SELECT_COLS_RE.captures(sql) else {
462        return Ok(Vec::new());
463    };
464    let cols_part = cap[1].trim();
465
466    // Multi-table JOIN path: when the outer FROM contains a JOIN, route
467    // each select expression through the shared multi-table resolver.
468    // `has_outer_join` scopes the check to the outer FROM body so that
469    // subqueries with JOINs (e.g. `WHERE id IN (SELECT ... JOIN ...)`)
470    // don't false-trigger.
471    if has_outer_join(sql) {
472        return resolve_multi_table_columns(cols_part, sql, schema_tables, source_file);
473    }
474
475    if cols_part == "*" {
476        return Ok(table.map(|t| t.columns.clone()).unwrap_or_default());
477    }
478
479    let Some(table) = table else {
480        return Ok(Vec::new());
481    };
482
483    let col_names: Vec<&str> = cols_part.split(',').map(|s| s.trim()).collect();
484
485    col_names
486        .iter()
487        .map(|&col_expr| -> Result<ColumnDef> {
488            ensure_supported_select_expr(col_expr, source_file)?;
489            let expr_lower = col_expr.to_lowercase();
490            if let Some(alias_cap) = ALIAS_RE.captures(&expr_lower) {
491                let actual = &alias_cap[1];
492                let alias = alias_cap[2].to_string();
493                Ok(table
494                    .columns
495                    .iter()
496                    .find(|c| c.name == actual)
497                    .map(|c| {
498                        let mut col = c.clone();
499                        col.alias = Some(alias);
500                        col
501                    })
502                    .unwrap_or_else(|| make_unknown_column(actual)))
503            } else {
504                Ok(table
505                    .columns
506                    .iter()
507                    .find(|c| c.name == expr_lower)
508                    .cloned()
509                    .unwrap_or_else(|| make_unknown_column(&expr_lower)))
510            }
511        })
512        .collect()
513}
514
515// ── Public API ───────────────────────────────────────────────────────────────
516
/// PostgreSQL dialect of the schema/query parser.
///
/// The type holds no state. All parsing work happens in its
/// `DatabaseParser` implementation.
pub struct PostgresParser;

impl PostgresParser {
    /// Builds a parser. Equivalent to [`PostgresParser::default`].
    pub fn new() -> Self {
        PostgresParser
    }
}

impl Default for PostgresParser {
    /// Delegates to [`PostgresParser::new`].
    fn default() -> Self {
        PostgresParser::new()
    }
}
530
531impl DatabaseParser for PostgresParser {
532    fn parse_schema(&self, sql: &str) -> Result<(Vec<TableDef>, Vec<EnumDef>)> {
533        let enums = parse_enum_defs(sql);
534        let enum_names: HashSet<String> = enums.iter().map(|e| e.name.clone()).collect();
535        let tables = parse_schema_tables(sql, &enum_names);
536        Ok((tables, enums))
537    }
538
539    fn parse_queries(
540        &self,
541        sql: &str,
542        tables: &[TableDef],
543        enums: &[EnumDef],
544        source_file: &str,
545    ) -> Result<Vec<QueryDef>> {
546        let _ = enums; // available for future use
547        let blocks = split_query_blocks(sql);
548        let mut queries = Vec::new();
549
550        for block in blocks {
551            let table = find_from_table(&block.sql, tables);
552            let param_indices = extract_param_indices(&block.sql);
553            let inferred_cols = infer_param_columns(&block.sql);
554            let params = build_params(&block.comments, table, param_indices, inferred_cols);
555            let returns = resolve_return_columns(&block.sql, table, tables, source_file)?;
556
557            let clean_sql = block
558                .sql
559                .trim_end()
560                .trim_end_matches(';')
561                .trim()
562                .to_string();
563
564            queries.push(QueryDef {
565                name: block.name,
566                command: block.command,
567                sql: clean_sql,
568                params,
569                returns,
570                source_file: source_file.to_string(),
571            });
572        }
573
574        Ok(queries)
575    }
576}
577
578// ── Tests ────────────────────────────────────────────────────────────────────
579
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{QueryCommand, SqlTypeCategory};
    use crate::parser::DatabaseParser;

    const SCHEMA_SQL: &str = include_str!("../../../../tests/fixtures/schema.sql");
    const QUERIES_SQL: &str = include_str!("../../../../tests/fixtures/queries/users.sql");

    // ── Fixture helpers ──────────────────────────────────────────────────────

    /// Parses the shared fixture schema with a fresh parser.
    fn fixture_schema() -> (Vec<TableDef>, Vec<EnumDef>) {
        PostgresParser::new().parse_schema(SCHEMA_SQL).unwrap()
    }

    /// Returns the fixture table called `name`, panicking if it is absent.
    fn fixture_table(name: &str) -> TableDef {
        let (tables, _) = fixture_schema();
        tables.into_iter().find(|t| t.name == name).unwrap()
    }

    /// Parses the fixture query file against the fixture schema and returns
    /// the query called `name`, panicking if it is absent.
    fn fixture_query(name: &str) -> QueryDef {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(SCHEMA_SQL).unwrap();
        parser
            .parse_queries(QUERIES_SQL, &tables, &enums, "queries/users.sql")
            .unwrap()
            .into_iter()
            .find(|q| q.name == name)
            .unwrap()
    }

    #[test]
    fn parses_enum_type() {
        let (_, enums) = fixture_schema();
        assert_eq!(enums.len(), 1);
        assert_eq!(enums[0].name, "user_status");
        assert_eq!(enums[0].values, vec!["active", "inactive", "banned"]);
    }

    #[test]
    fn parses_users_table() {
        let users = fixture_table("users");
        assert_eq!(users.columns.len(), 7);
        assert_eq!(users.primary_key, vec!["id"]);

        let id_col = &users.columns[0];
        assert_eq!(id_col.name, "id");
        assert_eq!(id_col.sql_type.category, SqlTypeCategory::Number);
        assert!(id_col.has_default); // SERIAL has implicit default
        assert!(!id_col.nullable);

        let column = |name: &str| users.columns.iter().find(|c| c.name == name).unwrap();
        assert!(column("bio").nullable);
        assert!(column("tags").sql_type.element_type.is_some());
    }

    #[test]
    fn parses_posts_table() {
        assert_eq!(fixture_table("posts").columns.len(), 6);
    }

    #[test]
    fn parses_get_user_query() {
        let get_user = fixture_query("GetUser");
        assert_eq!(get_user.command, QueryCommand::One);
        assert_eq!(get_user.params.len(), 1);
        assert_eq!(get_user.params[0].name, "id");
        assert_eq!(get_user.returns.len(), 7); // SELECT * returns all columns
    }

    #[test]
    fn parses_list_users_partial_select() {
        let list_users = fixture_query("ListUsers");
        assert_eq!(list_users.command, QueryCommand::Many);
        assert_eq!(list_users.returns.len(), 3); // SELECT id, name, email
    }

    #[test]
    fn parses_create_user_exec() {
        let create_user = fixture_query("CreateUser");
        assert_eq!(create_user.command, QueryCommand::Exec);
        assert_eq!(create_user.params.len(), 3);
        assert!(create_user.returns.is_empty());
    }

    #[test]
    fn parses_delete_user_execresult() {
        assert_eq!(
            fixture_query("DeleteUser").command,
            QueryCommand::ExecResult
        );
    }

    #[test]
    fn parses_param_overrides() {
        let date_range = fixture_query("ListUsersByDateRange");
        assert_eq!(date_range.params[0].name, "start_date");
        assert_eq!(date_range.params[1].name, "end_date");
    }

    #[test]
    fn resolve_type_maps_common_types() {
        let enums = HashSet::new();
        let cases = [
            ("TEXT", SqlTypeCategory::String),
            ("INTEGER", SqlTypeCategory::Number),
            ("BOOLEAN", SqlTypeCategory::Boolean),
            ("TIMESTAMP", SqlTypeCategory::Date),
            ("JSONB", SqlTypeCategory::Json),
            ("UUID", SqlTypeCategory::Uuid),
            ("BYTEA", SqlTypeCategory::Binary),
        ];
        for (raw, expected) in cases {
            assert_eq!(resolve_sql_type(raw, &enums).category, expected, "type {raw}");
        }
    }

    #[test]
    fn resolve_type_array() {
        let arr = resolve_sql_type("TEXT[]", &HashSet::new());
        assert_eq!(arr.category, SqlTypeCategory::String);
        let element = arr
            .element_type
            .expect("array type should carry its element type");
        assert_eq!(element.category, SqlTypeCategory::String);
    }

    #[test]
    fn resolve_type_enum() {
        let enums: HashSet<String> = ["user_status".to_string()].into_iter().collect();
        let t = resolve_sql_type("user_status", &enums);
        assert_eq!(t.category, SqlTypeCategory::Enum);
        assert_eq!(t.enum_name.as_deref(), Some("user_status"));
    }

    #[test]
    fn infer_insert_params() {
        let cols = infer_param_columns("INSERT INTO users (name, email, bio) VALUES ($1, $2, $3)");
        for (idx, name) in [(1_u32, "name"), (2_u32, "email"), (3_u32, "bio")] {
            assert_eq!(cols.get(&idx).map(String::as_str), Some(name));
        }
    }

    #[test]
    fn infer_where_params() {
        let cols = infer_param_columns("SELECT * FROM users WHERE id = $1");
        assert_eq!(cols.get(&1).map(String::as_str), Some("id"));
    }

    #[test]
    fn split_query_blocks_basic() {
        let blocks = split_query_blocks(
            "-- name: GetUser :one\nSELECT * FROM users WHERE id = $1;\n\n-- name: ListUsers :many\nSELECT id, name FROM users;",
        );
        assert_eq!(blocks.len(), 2);
        assert_eq!(blocks[0].name, "GetUser");
        assert_eq!(blocks[1].name, "ListUsers");
    }

    #[test]
    fn resolve_parser_postgres() {
        assert!(crate::parser::resolve_parser("postgres").is_ok());
    }

    #[test]
    fn resolve_parser_mysql() {
        assert!(crate::parser::resolve_parser("mysql").is_ok());
    }

    #[test]
    fn resolve_parser_sqlite() {
        assert!(crate::parser::resolve_parser("sqlite").is_ok());
    }

    #[test]
    fn resolve_parser_unknown() {
        assert!(crate::parser::resolve_parser("oracle").is_err());
    }

    // ── INNER JOIN path tests ────────────────────────────────────────────────

    fn join_schema() -> &'static str {
        r#"
        CREATE TABLE users (
          id INTEGER PRIMARY KEY,
          name TEXT NOT NULL,
          org_id INTEGER NOT NULL
        );
        CREATE TABLE orgs (
          id INTEGER PRIMARY KEY,
          slug TEXT NOT NULL
        );
        "#
    }

    /// Parses `sql` against `join_schema()` under the source name `q.sql`.
    fn parse_join_queries(sql: &str) -> Result<Vec<QueryDef>> {
        let parser = PostgresParser::new();
        let (tables, enums) = parser.parse_schema(join_schema()).unwrap();
        parser.parse_queries(sql, &tables, &enums, "q.sql")
    }

    #[test]
    fn inner_join_resolves_qualified_columns() {
        let sql = "-- name: GetUserWithOrg :one\nSELECT users.name, orgs.slug FROM users INNER JOIN orgs ON users.org_id = orgs.id WHERE users.id = $1;";
        let queries = parse_join_queries(sql).unwrap();
        assert_eq!(queries.len(), 1);
        let q = &queries[0];
        assert_eq!(q.returns.len(), 2);
        let summary: Vec<(&str, Option<&str>)> = q
            .returns
            .iter()
            .map(|c| (c.name.as_str(), c.source_table.as_deref()))
            .collect();
        assert_eq!(summary, [("name", Some("users")), ("slug", Some("orgs"))]);
    }

    #[test]
    fn inner_join_accepts_aliases_and_as() {
        let sql = "-- name: Listing :many\nSELECT u.id AS user_id, o.slug AS org_slug FROM users u INNER JOIN orgs o ON u.org_id = o.id;";
        let queries = parse_join_queries(sql).unwrap();
        let (user_id, org_slug) = (&queries[0].returns[0], &queries[0].returns[1]);
        assert_eq!(user_id.name, "id");
        assert_eq!(user_id.alias.as_deref(), Some("user_id"));
        assert_eq!(user_id.source_table.as_deref(), Some("users"));
        assert_eq!(org_slug.alias.as_deref(), Some("org_slug"));
        assert_eq!(org_slug.source_table.as_deref(), Some("orgs"));
    }

    #[test]
    fn inner_join_rejects_select_star() {
        let sql = "-- name: Everything :many\nSELECT * FROM users INNER JOIN orgs ON users.org_id = orgs.id;";
        let message = parse_join_queries(sql).unwrap_err().to_string();
        assert!(message.contains("SELECT * across multi-table JOINs"));
    }

    #[test]
    fn left_join_rejected_with_v12_pointer() {
        let sql = "-- name: WithLeft :many\nSELECT users.id FROM users LEFT JOIN orgs ON users.org_id = orgs.id;";
        let message = parse_join_queries(sql).unwrap_err().to_string();
        assert!(message.contains("v1.1 supports INNER JOIN only"));
    }

    #[test]
    fn single_table_path_still_rejects_qualified_selects() {
        // Queries without JOIN go through the existing single-table path,
        // which still rejects qualified selects via ensure_supported_select_expr.
        // (PR #32 is the separate effort that relaxes this for single-table queries.)
        let sql = "-- name: Bad :one\nSELECT users.id FROM users WHERE users.id = $1;";
        let message = parse_join_queries(sql).unwrap_err().to_string();
        assert!(message.contains("qualified select expressions are not supported"));
    }

    #[test]
    fn join_in_subquery_does_not_route_outer_to_multi_table() {
        // The outer FROM is single-table (`users`); the JOIN lives inside a
        // subquery. The outer query must use the single-table path. Routing it
        // to the multi-table resolver would make the unqualified outer `id`
        // select fail with "requires qualified columns".
        let sql = "-- name: SubquerySafe :many\nSELECT id FROM users WHERE id IN (SELECT users.id FROM users INNER JOIN orgs ON users.org_id = orgs.id);";
        let queries = parse_join_queries(sql).unwrap();
        let returns = &queries[0].returns;
        assert_eq!(returns.len(), 1);
        assert_eq!(returns[0].name, "id");
        assert_eq!(returns[0].source_table, None);
    }
}