1use std::collections::HashMap;
2
3use regex::Regex;
4
5use crate::ir::{JsonShape, QueryCommand};
6
7pub struct QueryHeader {
10 pub name: String,
11 pub command: QueryCommand,
12}
13
14pub struct Annotations {
15 pub enums: HashMap<String, Vec<String>>,
16 pub json_shapes: HashMap<String, JsonShape>,
17 pub param_overrides: HashMap<u32, String>,
18 pub query_header: Option<QueryHeader>,
19}
20
21impl Annotations {
22 fn new() -> Self {
23 Self {
24 enums: HashMap::new(),
25 json_shapes: HashMap::new(),
26 param_overrides: HashMap::new(),
27 query_header: None,
28 }
29 }
30}
31
32fn header_re() -> Regex {
35 Regex::new(r"--\s*name:\s*(\w+)\s+:(one|many|exec(?:result)?)").unwrap()
36}
37
38fn param_re() -> Regex {
39 Regex::new(r"--\s*@param\s+\$(\d+)\s+(\w+)").unwrap()
40}
41
42fn enum_re() -> Regex {
43 Regex::new(r#"--\s*@enum\s*\(\s*(.*?)\s*\)"#).unwrap()
44}
45
46fn json_re() -> Regex {
47 Regex::new(r"--\s*@json\s*\(\s*([\s\S]+?)\s*\)\s*$").unwrap()
48}
49
/// Finds the first whitespace-separated token of the first "real" line at or
/// after index `start`.
///
/// Lines that are blank or that start with `--` (after trimming) are skipped.
/// Returns `None` when no other kind of line remains, including when `start`
/// is past the end of `lines`.
fn find_next_column_name<'a>(lines: &[&'a str], start: usize) -> Option<&'a str> {
    lines
        .iter()
        .skip(start)
        .copied()
        .map(str::trim)
        .find(|candidate| !(candidate.is_empty() || candidate.starts_with("--")))
        .and_then(|candidate| candidate.split_whitespace().next())
}
63
/// Splits `s` on commas that are not nested inside `{}` or `()`.
///
/// The returned pieces are not trimmed. There is always at least one piece
/// (the whole input when no top-level comma exists). Unbalanced closing
/// brackets never push the nesting level below zero.
#[cfg(test)]
fn split_top_level(s: &str) -> Vec<&str> {
    let mut nesting = 0usize;
    let mut piece_start = 0;
    let mut pieces = Vec::new();

    // Every delimiter is ASCII, so scanning bytes finds the same positions as
    // scanning chars, and slicing at those positions stays on char boundaries.
    for (idx, byte) in s.bytes().enumerate() {
        if byte == b'{' || byte == b'(' {
            nesting += 1;
        } else if byte == b'}' || byte == b')' {
            nesting = nesting.saturating_sub(1);
        } else if byte == b',' && nesting == 0 {
            pieces.push(&s[piece_start..idx]);
            piece_start = idx + 1;
        }
    }

    pieces.push(&s[piece_start..]);
    pieces
}
85
/// Extracts every double-quoted value from the inside of an `@enum(...)` list.
///
/// Values are returned in order of appearance, without their quotes; text
/// between the quoted values (commas, whitespace, anything else) is ignored.
/// A trailing opening quote with no matching closing quote is dropped.
///
/// Implemented with plain string splitting so that no regex has to be
/// compiled each time an annotation is parsed.
fn parse_enum_values(inner: &str) -> Vec<String> {
    let mut values = Vec::new();
    let mut rest = inner;

    // Repeatedly: skip to the next opening quote, then take everything up to
    // the closing quote. Stops as soon as either quote is missing.
    while let Some((value, tail)) = rest
        .split_once('"')
        .and_then(|(_, opened)| opened.split_once('"'))
    {
        values.push(value.to_owned());
        rest = tail;
    }

    values
}
90
91struct JsonParser<'a> {
94 input: &'a str,
95 pos: usize,
96}
97
98impl<'a> JsonParser<'a> {
99 fn new(input: &'a str) -> Self {
100 Self { input, pos: 0 }
101 }
102
103 fn parse(mut self) -> Result<JsonShape, String> {
104 let shape = self.parse_type()?;
105 self.skip_ws();
106 if self.pos < self.input.len() {
107 return Err(format!(
108 "unexpected trailing content at pos {}: {:?}",
109 self.pos,
110 &self.input[self.pos..].chars().take(10).collect::<String>()
111 ));
112 }
113 Ok(shape)
114 }
115
116 fn parse_type(&mut self) -> Result<JsonShape, String> {
117 self.skip_ws();
118 let mut shape = if self.peek() == Some('{') {
119 self.parse_object()?
120 } else {
121 self.parse_primitive()?
122 };
123
124 self.skip_ws();
126 while self.look_ahead("[]") {
127 self.pos += 2;
128 self.skip_ws();
129 shape = JsonShape::Array {
130 element: Box::new(shape),
131 };
132 }
133
134 if self.peek() == Some('?') {
136 self.pos += 1;
137 shape = JsonShape::Nullable {
138 inner: Box::new(shape),
139 };
140 }
141
142 Ok(shape)
143 }
144
145 fn parse_primitive(&mut self) -> Result<JsonShape, String> {
146 self.skip_ws();
147 if self.match_word("string") {
148 return Ok(JsonShape::String);
149 }
150 if self.match_word("number") {
151 return Ok(JsonShape::Number);
152 }
153 if self.match_word("boolean") {
154 return Ok(JsonShape::Boolean);
155 }
156 Err(format!(
157 "unexpected token at pos {}: {:?}",
158 self.pos,
159 self.input[self.pos..].chars().take(10).collect::<String>()
160 ))
161 }
162
163 fn parse_object(&mut self) -> Result<JsonShape, String> {
164 self.consume('{')?;
165 self.skip_ws();
166 let mut fields = HashMap::new();
167
168 if self.peek() != Some('}') {
169 self.parse_field(&mut fields)?;
170 while self.peek() == Some(',') {
171 self.pos += 1;
172 self.skip_ws();
173 if self.peek() == Some('}') {
174 break; }
176 self.parse_field(&mut fields)?;
177 }
178 }
179
180 self.consume('}')?;
181 Ok(JsonShape::Object { fields })
182 }
183
184 fn parse_field(&mut self, fields: &mut HashMap<String, JsonShape>) -> Result<(), String> {
185 self.skip_ws();
186 let name = self.read_identifier()?;
187 self.skip_ws();
188 self.consume(':')?;
189 self.skip_ws();
190 let shape = self.parse_type()?;
191 self.skip_ws();
192 fields.insert(name, shape);
193 Ok(())
194 }
195
196 fn read_identifier(&mut self) -> Result<String, String> {
197 self.skip_ws();
198 let start = self.pos;
199 while self.pos < self.input.len()
200 && self
201 .input
202 .as_bytes()
203 .get(self.pos)
204 .map(|b| b.is_ascii_alphanumeric() || *b == b'_')
205 .unwrap_or(false)
206 {
207 self.pos += 1;
208 }
209 if self.pos == start {
210 return Err(format!("expected identifier at pos {}", self.pos));
211 }
212 Ok(self.input[start..self.pos].to_string())
213 }
214
215 fn skip_ws(&mut self) {
216 while self.pos < self.input.len() && self.input.as_bytes()[self.pos].is_ascii_whitespace() {
217 self.pos += 1;
218 }
219 }
220
221 fn peek(&mut self) -> Option<char> {
222 self.skip_ws();
223 self.input[self.pos..].chars().next()
224 }
225
226 fn look_ahead(&self, s: &str) -> bool {
227 self.input[self.pos..].starts_with(s)
228 }
229
230 fn match_word(&mut self, word: &str) -> bool {
231 if self.input[self.pos..].starts_with(word) {
232 let after = self.pos + word.len();
233 let next_is_word_char = self
234 .input
235 .as_bytes()
236 .get(after)
237 .map(|b| b.is_ascii_alphanumeric() || *b == b'_')
238 .unwrap_or(false);
239 if !next_is_word_char {
240 self.pos = after;
241 return true;
242 }
243 }
244 false
245 }
246
247 fn consume(&mut self, ch: char) -> Result<(), String> {
248 self.skip_ws();
249 match self.input[self.pos..].chars().next() {
250 Some(c) if c == ch => {
251 self.pos += ch.len_utf8();
252 Ok(())
253 }
254 other => Err(format!(
255 "expected {:?} at pos {}, got {:?}",
256 ch, self.pos, other
257 )),
258 }
259 }
260}
261
262fn parse_json_shape(body: &str) -> Option<JsonShape> {
263 match JsonParser::new(body.trim()).parse() {
264 Ok(shape) => Some(shape),
265 Err(e) => {
266 eprintln!("warning: failed to parse @json annotation: {e}");
267 None
268 }
269 }
270}
271
272pub fn extract_annotations(sql: &str) -> (String, Annotations) {
277 let lines: Vec<&str> = sql.lines().collect();
278 let mut annotations = Annotations::new();
279 let mut kept_lines: Vec<&str> = Vec::new();
280
281 let h_re = header_re();
282 let p_re = param_re();
283 let e_re = enum_re();
284 let j_re = json_re();
285
286 let mut i = 0;
287 while i < lines.len() {
288 let line = lines[i];
289 let trimmed = line.trim();
290
291 if let Some(cap) = h_re.captures(trimmed) {
293 let name = cap[1].to_string();
294 let command = match &cap[2] {
295 "one" => QueryCommand::One,
296 "many" => QueryCommand::Many,
297 "execresult" => QueryCommand::ExecResult,
298 _ => QueryCommand::Exec,
299 };
300 annotations.query_header = Some(QueryHeader { name, command });
301 i += 1;
302 continue;
303 }
304
305 if let Some(cap) = p_re.captures(trimmed) {
307 let idx: u32 = cap[1].parse().unwrap_or(0);
308 let name = cap[2].to_string();
309 annotations.param_overrides.insert(idx, name);
310 i += 1;
311 continue;
312 }
313
314 if let Some(cap) = e_re.captures(trimmed) {
316 let values = parse_enum_values(&cap[1]);
317 if !values.is_empty()
318 && let Some(col) = find_next_column_name(&lines, i + 1)
319 {
320 annotations.enums.insert(col.to_lowercase(), values);
321 }
322 i += 1;
323 continue;
324 }
325
326 if let Some(cap) = j_re.captures(trimmed) {
328 if let Some(shape) = parse_json_shape(&cap[1])
329 && let Some(col) = find_next_column_name(&lines, i + 1)
330 {
331 annotations.json_shapes.insert(col.to_lowercase(), shape);
332 }
333 i += 1;
334 continue;
335 }
336
337 kept_lines.push(line);
338 i += 1;
339 }
340
341 let cleaned = kept_lines.join("\n");
342 (cleaned, annotations)
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the extractor and returns only the collected annotations.
    fn annotations_of(sql: &str) -> Annotations {
        extract_annotations(sql).1
    }

    /// Runs the extractor and returns only the cleaned SQL.
    fn cleaned_of(sql: &str) -> String {
        extract_annotations(sql).0
    }

    /// Returns the command of the query header, which must be present.
    fn command_of(sql: &str) -> QueryCommand {
        annotations_of(sql).query_header.unwrap().command
    }

    #[test]
    fn extract_query_header() {
        let (cleaned, ann) =
            extract_annotations("-- name: GetUser :one\nSELECT * FROM users WHERE id = $1;");
        let QueryHeader { name, command } = ann.query_header.unwrap();
        assert_eq!(name, "GetUser");
        assert_eq!(command, QueryCommand::One);
        assert!(!cleaned.contains("-- name:"));
    }

    #[test]
    fn extract_enum_annotation() {
        let ann = annotations_of(
            "-- @enum(\"draft\", \"published\", \"archived\")\nstatus TEXT NOT NULL",
        );
        let expected = vec!["draft", "published", "archived"];
        assert_eq!(ann.enums.get("status").unwrap(), &expected);
    }

    #[test]
    fn extract_json_annotation() {
        let ann = annotations_of(
            "-- @json({ theme: string, notifications: boolean })\npreferences JSONB",
        );
        match ann.json_shapes.get("preferences").unwrap() {
            JsonShape::Object { fields } => {
                for key in ["theme", "notifications"] {
                    assert!(fields.contains_key(key));
                }
            }
            _ => panic!("expected Object shape"),
        }
    }

    #[test]
    fn extract_param_override() {
        let ann =
            annotations_of("-- @param $1 start_date\n-- @param $2 end_date\nSELECT * FROM users;");
        let name_of = |idx: u32| ann.param_overrides.get(&idx).cloned();
        assert_eq!(name_of(1), Some("start_date".to_string()));
        assert_eq!(name_of(2), Some("end_date".to_string()));
    }

    #[test]
    fn strips_annotation_lines_from_sql() {
        let cleaned = cleaned_of(
            "-- name: GetUser :one\n-- @param $1 user_id\nSELECT * FROM users WHERE id = $1;",
        );
        assert!(!cleaned.contains("@param"));
        assert!(!cleaned.contains("-- name:"));
        assert!(cleaned.contains("SELECT"));
    }

    #[test]
    fn regular_comments_are_preserved() {
        let cleaned = cleaned_of("-- This is a regular comment\nSELECT 1;");
        assert!(cleaned.contains("-- This is a regular comment"));
        assert!(cleaned.contains("SELECT 1;"));
    }

    #[test]
    fn query_command_many() {
        let command = command_of("-- name: ListUsers :many\nSELECT * FROM users;");
        assert_eq!(command, QueryCommand::Many);
    }

    #[test]
    fn query_command_exec() {
        let command = command_of("-- name: DeleteUser :exec\nDELETE FROM users WHERE id = $1;");
        assert_eq!(command, QueryCommand::Exec);
    }

    #[test]
    fn query_command_execresult() {
        let command = command_of(
            "-- name: UpdateUser :execresult\nUPDATE users SET name = $1 WHERE id = $2;",
        );
        assert_eq!(command, QueryCommand::ExecResult);
    }

    #[test]
    fn json_array_shape() {
        let ann = annotations_of("-- @json(string[])\ntags TEXT[]");
        match ann.json_shapes.get("tags").unwrap() {
            JsonShape::Array { element } => assert!(matches!(**element, JsonShape::String)),
            _ => panic!("expected Array shape"),
        }
    }

    #[test]
    fn json_nullable_shape() {
        let ann = annotations_of("-- @json(string?)\nnickname TEXT");
        match ann.json_shapes.get("nickname").unwrap() {
            JsonShape::Nullable { inner } => assert!(matches!(**inner, JsonShape::String)),
            _ => panic!("expected Nullable shape"),
        }
    }

    #[test]
    fn split_top_level_nested() {
        let trimmed: Vec<&str> = split_top_level("a, { b: c, d: e }, f")
            .into_iter()
            .map(str::trim)
            .collect();
        assert_eq!(trimmed.len(), 3);
        assert_eq!(trimmed[0], "a");
        assert_eq!(trimmed[1], "{ b: c, d: e }");
        assert_eq!(trimmed[2], "f");
    }

    #[test]
    fn empty_sql_no_panic() {
        let (cleaned, ann) = extract_annotations("");
        assert_eq!(cleaned, "");
        assert!(ann.query_header.is_none());
        assert!(ann.enums.is_empty());
    }
}