//! `sqlcx_core/annotations.rs`: extraction of `-- name:`, `@param`, `@enum`, and `@json`
//! annotations from SQL text.

use std::collections::HashMap;
use std::sync::LazyLock;

use regex::Regex;

use crate::ir::{JsonShape, QueryCommand};

// ── Public types ──────────────────────────────────────────────────────────────

9pub struct QueryHeader {
10    pub name: String,
11    pub command: QueryCommand,
12}
13
14pub struct Annotations {
15    pub enums: HashMap<String, Vec<String>>,
16    pub json_shapes: HashMap<String, JsonShape>,
17    pub param_overrides: HashMap<u32, String>,
18    pub query_header: Option<QueryHeader>,
19}
20
21impl Annotations {
22    fn new() -> Self {
23        Self {
24            enums: HashMap::new(),
25            json_shapes: HashMap::new(),
26            param_overrides: HashMap::new(),
27            query_header: None,
28        }
29    }
30}
31
// ── Regex patterns ────────────────────────────────────────────────────────────

34fn header_re() -> Regex {
35    Regex::new(r"--\s*name:\s*(\w+)\s+:(one|many|exec(?:result)?)").unwrap()
36}
37
38fn param_re() -> Regex {
39    Regex::new(r"--\s*@param\s+\$(\d+)\s+(\w+)").unwrap()
40}
41
42fn enum_re() -> Regex {
43    Regex::new(r#"--\s*@enum\s*\(\s*(.*?)\s*\)"#).unwrap()
44}
45
46fn json_re() -> Regex {
47    Regex::new(r"--\s*@json\s*\(\s*([\s\S]+?)\s*\)\s*$").unwrap()
48}
49
// ── Helpers ───────────────────────────────────────────────────────────────────

/// Returns the column name on the first line at or after index `start` that is neither
/// blank nor a `--` comment.
///
/// The name is that line's first whitespace-separated token, with a trailing `,` removed and
/// then any surrounding identifier quotes (`"` or `` ` ``) removed. For example,
/// `"UserName" TEXT,` yields `UserName`. Returns `None` when no such line exists or nothing
/// is left after stripping.
fn find_next_column_name<'a>(lines: &[&'a str], start: usize) -> Option<&'a str> {
    let line = lines
        .iter()
        .skip(start)
        .copied()
        .map(str::trim)
        .find(|t| !t.is_empty() && !t.starts_with("--"))?;
    let name = line
        .split_whitespace()
        .next()?
        .trim_end_matches(',')
        .trim_matches(|c: char| c == '"' || c == '`');
    (!name.is_empty()).then_some(name)
}

/// Splits `s` on the commas that sit at nesting depth zero.
///
/// Commas enclosed in `{ }` or `( )` stay inside their segment, and an unmatched closer never
/// drives the depth below zero. Segments are returned untrimmed; the result always holds at
/// least one (possibly empty) segment.
#[cfg(test)]
fn split_top_level(s: &str) -> Vec<&str> {
    let mut segments = Vec::new();
    let mut nesting: usize = 0;
    let mut seg_start = 0;

    for (idx, c) in s.char_indices() {
        if c == '{' || c == '(' {
            nesting += 1;
        } else if c == '}' || c == ')' {
            nesting = nesting.saturating_sub(1);
        } else if c == ',' && nesting == 0 {
            segments.push(&s[seg_start..idx]);
            seg_start = idx + c.len_utf8();
        }
    }

    segments.push(&s[seg_start..]);
    segments
}

86fn parse_enum_values(inner: &str) -> Vec<String> {
87    let re = Regex::new(r#""([^"]*?)""#).unwrap();
88    re.captures_iter(inner).map(|c| c[1].to_string()).collect()
89}
90
// ── JSON shape parser ─────────────────────────────────────────────────────────

93struct JsonParser<'a> {
94    input: &'a str,
95    pos: usize,
96}
97
98impl<'a> JsonParser<'a> {
99    fn new(input: &'a str) -> Self {
100        Self { input, pos: 0 }
101    }
102
103    fn parse(mut self) -> Result<JsonShape, String> {
104        let shape = self.parse_type()?;
105        self.skip_ws();
106        if self.pos < self.input.len() {
107            return Err(format!(
108                "unexpected trailing content at pos {}: {:?}",
109                self.pos,
110                &self.input[self.pos..].chars().take(10).collect::<String>()
111            ));
112        }
113        Ok(shape)
114    }
115
116    fn parse_type(&mut self) -> Result<JsonShape, String> {
117        self.skip_ws();
118        let mut shape = if self.peek() == Some('{') {
119            self.parse_object()?
120        } else {
121            self.parse_primitive()?
122        };
123
124        // array suffix []
125        self.skip_ws();
126        while self.look_ahead("[]") {
127            self.pos += 2;
128            self.skip_ws();
129            shape = JsonShape::Array {
130                element: Box::new(shape),
131            };
132        }
133
134        // nullable suffix ?
135        if self.peek() == Some('?') {
136            self.pos += 1;
137            shape = JsonShape::Nullable {
138                inner: Box::new(shape),
139            };
140        }
141
142        Ok(shape)
143    }
144
145    fn parse_primitive(&mut self) -> Result<JsonShape, String> {
146        self.skip_ws();
147        if self.match_word("string") {
148            return Ok(JsonShape::String);
149        }
150        if self.match_word("number") {
151            return Ok(JsonShape::Number);
152        }
153        if self.match_word("boolean") {
154            return Ok(JsonShape::Boolean);
155        }
156        Err(format!(
157            "unexpected token at pos {}: {:?}",
158            self.pos,
159            self.input[self.pos..].chars().take(10).collect::<String>()
160        ))
161    }
162
163    fn parse_object(&mut self) -> Result<JsonShape, String> {
164        self.consume('{')?;
165        self.skip_ws();
166        let mut fields = HashMap::new();
167
168        if self.peek() != Some('}') {
169            self.parse_field(&mut fields)?;
170            while self.peek() == Some(',') {
171                self.pos += 1;
172                self.skip_ws();
173                if self.peek() == Some('}') {
174                    break; // trailing comma
175                }
176                self.parse_field(&mut fields)?;
177            }
178        }
179
180        self.consume('}')?;
181        Ok(JsonShape::Object { fields })
182    }
183
184    fn parse_field(&mut self, fields: &mut HashMap<String, JsonShape>) -> Result<(), String> {
185        self.skip_ws();
186        let name = self.read_identifier()?;
187        self.skip_ws();
188        self.consume(':')?;
189        self.skip_ws();
190        let shape = self.parse_type()?;
191        self.skip_ws();
192        fields.insert(name, shape);
193        Ok(())
194    }
195
196    fn read_identifier(&mut self) -> Result<String, String> {
197        self.skip_ws();
198        let start = self.pos;
199        while self.pos < self.input.len()
200            && self
201                .input
202                .as_bytes()
203                .get(self.pos)
204                .map(|b| b.is_ascii_alphanumeric() || *b == b'_')
205                .unwrap_or(false)
206        {
207            self.pos += 1;
208        }
209        if self.pos == start {
210            return Err(format!("expected identifier at pos {}", self.pos));
211        }
212        Ok(self.input[start..self.pos].to_string())
213    }
214
215    fn skip_ws(&mut self) {
216        while self.pos < self.input.len() && self.input.as_bytes()[self.pos].is_ascii_whitespace() {
217            self.pos += 1;
218        }
219    }
220
221    fn peek(&mut self) -> Option<char> {
222        self.skip_ws();
223        self.input[self.pos..].chars().next()
224    }
225
226    fn look_ahead(&self, s: &str) -> bool {
227        self.input[self.pos..].starts_with(s)
228    }
229
230    fn match_word(&mut self, word: &str) -> bool {
231        if self.input[self.pos..].starts_with(word) {
232            let after = self.pos + word.len();
233            let next_is_word_char = self
234                .input
235                .as_bytes()
236                .get(after)
237                .map(|b| b.is_ascii_alphanumeric() || *b == b'_')
238                .unwrap_or(false);
239            if !next_is_word_char {
240                self.pos = after;
241                return true;
242            }
243        }
244        false
245    }
246
247    fn consume(&mut self, ch: char) -> Result<(), String> {
248        self.skip_ws();
249        match self.input[self.pos..].chars().next() {
250            Some(c) if c == ch => {
251                self.pos += ch.len_utf8();
252                Ok(())
253            }
254            other => Err(format!(
255                "expected {:?} at pos {}, got {:?}",
256                ch, self.pos, other
257            )),
258        }
259    }
260}
261
262fn parse_json_shape(body: &str) -> Option<JsonShape> {
263    match JsonParser::new(body.trim()).parse() {
264        Ok(shape) => Some(shape),
265        Err(e) => {
266            eprintln!("warning: failed to parse @json annotation: {e}");
267            None
268        }
269    }
270}
271
// ── Main extraction function ──────────────────────────────────────────────────

274/// Extract annotations from SQL. Returns `(cleaned_sql, annotations)`.
275/// Annotation lines are removed; regular comments are preserved.
276pub fn extract_annotations(sql: &str) -> (String, Annotations) {
277    let lines: Vec<&str> = sql.lines().collect();
278    let mut annotations = Annotations::new();
279    let mut kept_lines: Vec<&str> = Vec::new();
280
281    let h_re = header_re();
282    let p_re = param_re();
283    let e_re = enum_re();
284    let j_re = json_re();
285
286    let mut i = 0;
287    while i < lines.len() {
288        let line = lines[i];
289        let trimmed = line.trim();
290
291        // Query header: -- name: Foo :one
292        if let Some(cap) = h_re.captures(trimmed) {
293            let name = cap[1].to_string();
294            let command = match &cap[2] {
295                "one" => QueryCommand::One,
296                "many" => QueryCommand::Many,
297                "execresult" => QueryCommand::ExecResult,
298                _ => QueryCommand::Exec,
299            };
300            annotations.query_header = Some(QueryHeader { name, command });
301            i += 1;
302            continue;
303        }
304
305        // Param override: -- @param $1 name
306        if let Some(cap) = p_re.captures(trimmed) {
307            let idx: u32 = cap[1].parse().unwrap_or(0);
308            let name = cap[2].to_string();
309            annotations.param_overrides.insert(idx, name);
310            i += 1;
311            continue;
312        }
313
314        // Enum annotation: -- @enum("a", "b")
315        if let Some(cap) = e_re.captures(trimmed) {
316            let values = parse_enum_values(&cap[1]);
317            if !values.is_empty()
318                && let Some(col) = find_next_column_name(&lines, i + 1)
319            {
320                annotations.enums.insert(col.to_lowercase(), values);
321            }
322            i += 1;
323            continue;
324        }
325
326        // JSON annotation: -- @json({ ... })
327        if let Some(cap) = j_re.captures(trimmed) {
328            if let Some(shape) = parse_json_shape(&cap[1])
329                && let Some(col) = find_next_column_name(&lines, i + 1)
330            {
331                annotations.json_shapes.insert(col.to_lowercase(), shape);
332            }
333            i += 1;
334            continue;
335        }
336
337        kept_lines.push(line);
338        i += 1;
339    }
340
341    let cleaned = kept_lines.join("\n");
342    (cleaned, annotations)
343}
344
// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn extract_query_header() {
        let sql = "-- name: GetUser :one\nSELECT * FROM users WHERE id = $1;";
        let (cleaned, ann) = extract_annotations(sql);
        let header = ann.query_header.unwrap();
        assert_eq!(header.name, "GetUser");
        assert_eq!(header.command, QueryCommand::One);
        assert!(!cleaned.contains("-- name:"));
    }

    #[test]
    fn extract_enum_annotation() {
        let sql = "-- @enum(\"draft\", \"published\", \"archived\")\nstatus TEXT NOT NULL";
        let (_, ann) = extract_annotations(sql);
        let values = ann.enums.get("status").unwrap();
        assert_eq!(values, &vec!["draft", "published", "archived"]);
    }

    #[test]
    fn extract_json_annotation() {
        let sql = "-- @json({ theme: string, notifications: boolean })\npreferences JSONB";
        let (_, ann) = extract_annotations(sql);
        let shape = ann.json_shapes.get("preferences").unwrap();
        match shape {
            JsonShape::Object { fields } => {
                assert!(fields.contains_key("theme"));
                assert!(fields.contains_key("notifications"));
            }
            _ => panic!("expected Object shape"),
        }
    }

    #[test]
    fn extract_param_override() {
        let sql = "-- @param $1 start_date\n-- @param $2 end_date\nSELECT * FROM users;";
        let (_, ann) = extract_annotations(sql);
        assert_eq!(ann.param_overrides.get(&1), Some(&"start_date".to_string()));
        assert_eq!(ann.param_overrides.get(&2), Some(&"end_date".to_string()));
    }

    #[test]
    fn strips_annotation_lines_from_sql() {
        let sql = "-- name: GetUser :one\n-- @param $1 user_id\nSELECT * FROM users WHERE id = $1;";
        let (cleaned, _) = extract_annotations(sql);
        assert!(!cleaned.contains("@param"));
        assert!(!cleaned.contains("-- name:"));
        assert!(cleaned.contains("SELECT"));
    }

    #[test]
    fn regular_comments_are_preserved() {
        let sql = "-- This is a regular comment\nSELECT 1;";
        let (cleaned, _) = extract_annotations(sql);
        assert!(cleaned.contains("-- This is a regular comment"));
        assert!(cleaned.contains("SELECT 1;"));
    }

    #[test]
    fn query_command_many() {
        let sql = "-- name: ListUsers :many\nSELECT * FROM users;";
        let (_, ann) = extract_annotations(sql);
        assert_eq!(ann.query_header.unwrap().command, QueryCommand::Many);
    }

    #[test]
    fn query_command_exec() {
        let sql = "-- name: DeleteUser :exec\nDELETE FROM users WHERE id = $1;";
        let (_, ann) = extract_annotations(sql);
        assert_eq!(ann.query_header.unwrap().command, QueryCommand::Exec);
    }

    #[test]
    fn query_command_execresult() {
        let sql = "-- name: UpdateUser :execresult\nUPDATE users SET name = $1 WHERE id = $2;";
        let (_, ann) = extract_annotations(sql);
        assert_eq!(ann.query_header.unwrap().command, QueryCommand::ExecResult);
    }

    #[test]
    fn json_array_shape() {
        let sql = "-- @json(string[])\ntags TEXT[]";
        let (_, ann) = extract_annotations(sql);
        let shape = ann.json_shapes.get("tags").unwrap();
        match shape {
            JsonShape::Array { element } => {
                assert!(matches!(**element, JsonShape::String));
            }
            _ => panic!("expected Array shape"),
        }
    }

    #[test]
    fn json_nullable_shape() {
        let sql = "-- @json(string?)\nnickname TEXT";
        let (_, ann) = extract_annotations(sql);
        let shape = ann.json_shapes.get("nickname").unwrap();
        match shape {
            JsonShape::Nullable { inner } => {
                assert!(matches!(**inner, JsonShape::String));
            }
            _ => panic!("expected Nullable shape"),
        }
    }

    // Nested arrays, nested objects and trailing commas in both object levels.
    #[test]
    fn json_nested_arrays_objects_and_trailing_commas() {
        let sql = "-- @json({ tags: string[][], meta: { a: number, }, })\ndata JSONB";
        let (cleaned, ann) = extract_annotations(sql);
        assert_eq!(cleaned, "data JSONB");
        let Some(JsonShape::Object { fields }) = ann.json_shapes.get("data") else {
            panic!("expected Object shape");
        };
        match &fields["tags"] {
            JsonShape::Array { element } => match &**element {
                JsonShape::Array { element } => {
                    assert!(matches!(**element, JsonShape::String));
                }
                _ => panic!("expected nested Array shape"),
            },
            _ => panic!("expected Array shape"),
        }
        match &fields["meta"] {
            JsonShape::Object { fields } => assert!(matches!(fields["a"], JsonShape::Number)),
            _ => panic!("expected nested Object shape"),
        }
    }

    // `?` applies after `[]`, giving a nullable array.
    #[test]
    fn json_nullable_array_shape() {
        let sql = "-- @json(number[]?)\nscores JSONB";
        let (_, ann) = extract_annotations(sql);
        match ann.json_shapes.get("scores").unwrap() {
            JsonShape::Nullable { inner } => match &**inner {
                JsonShape::Array { element } => {
                    assert!(matches!(**element, JsonShape::Number));
                }
                _ => panic!("expected Array inside Nullable"),
            },
            _ => panic!("expected Nullable shape"),
        }
    }

    // A shape that fails to parse is dropped, but its annotation line is still stripped.
    #[test]
    fn malformed_json_annotation_is_stripped_and_ignored() {
        let sql = "-- @json({ a: })\ndata JSONB";
        let (cleaned, ann) = extract_annotations(sql);
        assert!(ann.json_shapes.is_empty());
        assert_eq!(cleaned, "data JSONB");
    }

    // The target column is found past blank lines and regular comments, and is lowercased.
    #[test]
    fn enum_skips_blank_and_comment_lines_before_column() {
        let sql = "-- @enum(\"a\", \"b\")\n-- a regular note\n\n  Status TEXT NOT NULL";
        let (cleaned, ann) = extract_annotations(sql);
        assert_eq!(
            ann.enums.get("status"),
            Some(&vec!["a".to_string(), "b".to_string()])
        );
        assert!(cleaned.contains("-- a regular note"));
        assert!(!cleaned.contains("@enum"));
    }

    #[test]
    fn cleaned_sql_preserves_order_of_kept_lines() {
        let sql = "-- name: ListIds :many\nSELECT id\nFROM users\nORDER BY id;";
        let (cleaned, _) = extract_annotations(sql);
        assert_eq!(cleaned, "SELECT id\nFROM users\nORDER BY id;");
    }

    #[test]
    fn split_top_level_nested() {
        let parts = split_top_level("a, { b: c, d: e }, f");
        assert_eq!(parts.len(), 3);
        assert_eq!(parts[0].trim(), "a");
        assert_eq!(parts[1].trim(), "{ b: c, d: e }");
        assert_eq!(parts[2].trim(), "f");
    }

    #[test]
    fn empty_sql_no_panic() {
        let (cleaned, ann) = extract_annotations("");
        assert_eq!(cleaned, "");
        assert!(ann.query_header.is_none());
        assert!(ann.enums.is_empty());
    }
}