1use sqlparser::parser::Parser;
2
3use crate::dialect::SqlDialect;
4use crate::errors::ScytheError;
5
/// The result mode requested for a query through its `@returns` annotation.
///
/// Each variant corresponds to exactly one textual tag; the tag is what
/// appears after `@returns` (with or without a leading `:`) in the query file.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum QueryCommand {
    /// Selected by `@returns :one`.
    One,
    /// Selected by `@returns :many`.
    Many,
    /// Selected by `@returns :exec`.
    Exec,
    /// Selected by `@returns :exec_result`.
    ExecResult,
    /// Selected by `@returns :exec_rows`.
    ExecRows,
    /// Selected by `@returns :batch`.
    Batch,
    /// Selected by `@returns :grouped`; the parser additionally requires a
    /// `@group_by` annotation for this mode.
    Grouped,
}
16
17impl std::fmt::Display for QueryCommand {
18 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19 match self {
20 QueryCommand::One => write!(f, "one"),
21 QueryCommand::Many => write!(f, "many"),
22 QueryCommand::Exec => write!(f, "exec"),
23 QueryCommand::ExecResult => write!(f, "exec_result"),
24 QueryCommand::ExecRows => write!(f, "exec_rows"),
25 QueryCommand::Batch => write!(f, "batch"),
26 QueryCommand::Grouped => write!(f, "grouped"),
27 }
28 }
29}
30
31impl QueryCommand {
32 fn from_str(s: &str) -> Result<Self, ScytheError> {
33 match s {
34 "one" => Ok(QueryCommand::One),
35 "many" => Ok(QueryCommand::Many),
36 "exec" => Ok(QueryCommand::Exec),
37 "exec_result" => Ok(QueryCommand::ExecResult),
38 "exec_rows" => Ok(QueryCommand::ExecRows),
39 "batch" => Ok(QueryCommand::Batch),
40 "grouped" => Ok(QueryCommand::Grouped),
41 other => Err(ScytheError::invalid_annotation(format!(
42 "invalid @returns value: {other}"
43 ))),
44 }
45 }
46}
47
/// Documentation attached to a query parameter via `@param name: description`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ParamDoc {
    /// Text before the first `:` of the annotation value, trimmed
    /// (the whole value when there is no `:`).
    pub name: String,
    /// Text after the first `:`, trimmed; empty when no `:` was given.
    pub description: String,
}
53
/// A `@json column = Type` annotation, split at the first `=`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct JsonMapping {
    /// Left-hand side of the `=`, trimmed.
    pub column: String,
    /// Right-hand side of the `=`, trimmed.
    pub rust_type: String,
}
59
60#[derive(Debug, Clone, PartialEq, Eq)]
61pub struct Annotations {
62 pub name: String,
63 pub command: QueryCommand,
64 pub param_docs: Vec<ParamDoc>,
65 pub nullable_overrides: Vec<String>,
66 pub nonnull_overrides: Vec<String>,
67 pub json_mappings: Vec<JsonMapping>,
68 pub deprecated: Option<String>,
69 pub optional_params: Vec<String>,
70 pub group_by: Option<String>,
71}
72
73#[derive(Debug)]
74pub struct Query {
75 pub name: String,
76 pub command: QueryCommand,
77 pub sql: String,
78 pub stmt: sqlparser::ast::Statement,
79 pub annotations: Annotations,
80}
81
82pub fn parse_query(query_sql: &str) -> Result<Query, ScytheError> {
84 parse_query_with_dialect(query_sql, &SqlDialect::PostgreSQL)
85}
86
87pub fn parse_query_with_dialect(
89 query_sql: &str,
90 dialect: &SqlDialect,
91) -> Result<Query, ScytheError> {
92 let mut name: Option<String> = None;
93 let mut command: Option<QueryCommand> = None;
94 let mut param_docs = Vec::new();
95 let mut nullable_overrides = Vec::new();
96 let mut nonnull_overrides = Vec::new();
97 let mut json_mappings = Vec::new();
98 let mut deprecated: Option<String> = None;
99 let mut optional_params = Vec::new();
100 let mut group_by: Option<String> = None;
101
102 let mut sql_lines = Vec::new();
103
104 for line in query_sql.lines() {
105 let trimmed = line.trim();
106
107 let annotation_body = if let Some(rest) = trimmed.strip_prefix("--") {
109 let rest = rest.trim_start();
110 rest.strip_prefix('@')
111 } else {
112 None
113 };
114
115 if let Some(body) = annotation_body {
116 let (keyword, value) = match body.find(|c: char| c.is_whitespace()) {
118 Some(pos) => (&body[..pos], body[pos..].trim()),
119 None => (body, ""),
120 };
121
122 match keyword.to_ascii_lowercase().as_str() {
123 "name" => {
124 name = Some(value.to_string());
125 }
126 "returns" => {
127 let cmd_str = value.strip_prefix(':').unwrap_or(value);
128 command = Some(QueryCommand::from_str(cmd_str)?);
129 }
130 "param" => {
131 if let Some(colon_pos) = value.find(':') {
133 let param_name = value[..colon_pos].trim().to_string();
134 let description = value[colon_pos + 1..].trim().to_string();
135 param_docs.push(ParamDoc {
136 name: param_name,
137 description,
138 });
139 } else {
140 param_docs.push(ParamDoc {
141 name: value.to_string(),
142 description: String::new(),
143 });
144 }
145 }
146 "nullable" => {
147 for col in value.split(',') {
148 let col = col.trim();
149 if !col.is_empty() {
150 nullable_overrides.push(col.to_string());
151 }
152 }
153 }
154 "nonnull" => {
155 for col in value.split(',') {
156 let col = col.trim();
157 if !col.is_empty() {
158 nonnull_overrides.push(col.to_string());
159 }
160 }
161 }
162 "json" => {
163 if let Some(eq_pos) = value.find('=') {
165 let column = value[..eq_pos].trim().to_string();
166 let rust_type = value[eq_pos + 1..].trim().to_string();
167 json_mappings.push(JsonMapping { column, rust_type });
168 }
169 }
170 "deprecated" => {
171 deprecated = Some(value.to_string());
172 }
173 "group_by" => {
174 group_by = Some(value.to_string());
175 }
176 "optional" => {
177 for param in value.split(',') {
178 let param = param.trim();
179 if !param.is_empty() {
180 optional_params.push(param.to_string());
181 }
182 }
183 }
184 _ => {
185 }
187 }
188 } else {
189 sql_lines.push(line);
190 }
191 }
192
193 let name = name.ok_or_else(|| ScytheError::missing_annotation("name"))?;
194 let command = command.ok_or_else(|| ScytheError::missing_annotation("returns"))?;
195
196 if command == QueryCommand::Grouped && group_by.is_none() {
197 return Err(ScytheError::invalid_annotation(
198 "@returns :grouped requires a @group_by annotation (e.g. @group_by users.id)",
199 ));
200 }
201
202 let sql = sql_lines.join("\n").trim().to_string();
203
204 if sql.is_empty() {
205 return Err(ScytheError::syntax("empty SQL body"));
206 }
207
208 let parser_dialect = dialect.to_sqlparser_dialect();
209 let statements = Parser::parse_sql(parser_dialect.as_ref(), &sql)
210 .map_err(|e| ScytheError::syntax(format!("syntax error: {}", e)))?;
211
212 if statements.len() != 1 {
213 let non_empty: Vec<_> = statements
216 .into_iter()
217 .filter(|s| {
218 !matches!(s, sqlparser::ast::Statement::Flush { .. }) && format!("{s}") != ""
219 })
220 .collect();
221 if non_empty.len() != 1 {
222 return Err(ScytheError::syntax("expected exactly one SQL statement"));
223 }
224 let stmt = non_empty
225 .into_iter()
226 .next()
227 .expect("filtered to exactly one statement");
228 let annotations = Annotations {
229 name: name.clone(),
230 command: command.clone(),
231 param_docs,
232 nullable_overrides,
233 nonnull_overrides,
234 json_mappings,
235 deprecated,
236 optional_params,
237 group_by: group_by.clone(),
238 };
239 return Ok(Query {
240 name,
241 command,
242 sql,
243 stmt,
244 annotations,
245 });
246 }
247
248 let stmt = statements
249 .into_iter()
250 .next()
251 .expect("filtered to exactly one statement");
252
253 let annotations = Annotations {
254 name: name.clone(),
255 command: command.clone(),
256 param_docs,
257 nullable_overrides,
258 nonnull_overrides,
259 json_mappings,
260 deprecated,
261 optional_params,
262 group_by,
263 };
264
265 Ok(Query {
266 name,
267 command,
268 sql,
269 stmt,
270 annotations,
271 })
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::errors::ErrorCode;

    /// Parses with the default dialect; keeps the individual tests short.
    fn parse(sql: &str) -> Result<Query, ScytheError> {
        parse_query(sql)
    }

    #[test]
    fn test_basic_parse() {
        let input = "-- @name GetUsers\n-- @returns :many\nSELECT * FROM users;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
        assert!(q.sql.contains("SELECT"));
    }

    #[test]
    fn test_all_command_types() {
        let cases = [
            (":one", QueryCommand::One),
            (":many", QueryCommand::Many),
            (":exec", QueryCommand::Exec),
            (":exec_result", QueryCommand::ExecResult),
            (":exec_rows", QueryCommand::ExecRows),
        ];
        for (tag, expected) in cases {
            let input = format!("-- @name Q\n-- @returns {}\nSELECT 1", tag);
            let q = parse(&input).unwrap();
            assert_eq!(q.command, expected, "failed for {}", tag);
        }
    }

    #[test]
    fn test_command_display_round_trip() {
        // `Display` and `from_str` must agree on the tag of every variant.
        let commands = [
            QueryCommand::One,
            QueryCommand::Many,
            QueryCommand::Exec,
            QueryCommand::ExecResult,
            QueryCommand::ExecRows,
            QueryCommand::Batch,
            QueryCommand::Grouped,
        ];
        for cmd in commands {
            let parsed = QueryCommand::from_str(&cmd.to_string()).unwrap();
            assert_eq!(parsed, cmd);
        }
        assert_eq!(QueryCommand::ExecResult.to_string(), "exec_result");
    }

    #[test]
    fn test_case_insensitive_keywords() {
        let input = "-- @Name GetUsers\n-- @RETURNS :many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "GetUsers");
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_missing_name_errors() {
        let input = "-- @returns :many\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("name"));
    }

    #[test]
    fn test_missing_returns_errors() {
        let input = "-- @name Foo\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::MissingAnnotation);
        assert!(err.message.contains("returns"));
    }

    #[test]
    fn test_invalid_returns_value() {
        let input = "-- @name Foo\n-- @returns :invalid\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
    }

    #[test]
    fn test_empty_name_value() {
        // An `@name` with no value is currently accepted and yields an empty name.
        let input = "-- @name\n-- @returns :one\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "");
    }

    #[test]
    fn test_param_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id: the user ID\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "the user ID");
    }

    #[test]
    fn test_param_no_description() {
        let input = "-- @name Foo\n-- @returns :one\n-- @param id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.param_docs.len(), 1);
        assert_eq!(q.annotations.param_docs[0].name, "id");
        assert_eq!(q.annotations.param_docs[0].description, "");
    }

    #[test]
    fn test_nullable_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nullable col1, col2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nullable_overrides, vec!["col1", "col2"]);
    }

    #[test]
    fn test_nonnull_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @nonnull col1\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.nonnull_overrides, vec!["col1"]);
    }

    #[test]
    fn test_optional_annotation() {
        // Entries are trimmed and empty entries are dropped.
        let input = "-- @name Foo\n-- @returns :one\n-- @optional a, b,, c\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.optional_params, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_json_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @json data = EventData\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.json_mappings.len(), 1);
        assert_eq!(q.annotations.json_mappings[0].column, "data");
        assert_eq!(q.annotations.json_mappings[0].rust_type, "EventData");
    }

    #[test]
    fn test_deprecated_annotation() {
        let input = "-- @name Foo\n-- @returns :one\n-- @deprecated Use V2\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.annotations.deprecated, Some("Use V2".to_string()));
    }

    #[test]
    fn test_unknown_annotation_is_ignored() {
        // Unknown annotations are neither an error nor part of the SQL body.
        let input = "-- @name Foo\n-- @returns :one\n-- @whatever value\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "Foo");
        assert_eq!(q.sql, "SELECT 1");
    }

    #[test]
    fn test_sql_syntax_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELCT * FROM users";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_empty_sql_body_errors() {
        let input = "-- @name Foo\n-- @returns :one\n";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_trailing_semicolon() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1;";
        let q = parse(input).unwrap();
        assert_eq!(q.name, "Foo");
    }

    #[test]
    fn test_multiple_statements_error() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT 1; SELECT 2;";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::SyntaxError);
    }

    #[test]
    fn test_sql_preserved_without_annotations() {
        let input = "-- @name Foo\n-- @returns :one\nSELECT id, name FROM users WHERE id = $1";
        let q = parse(input).unwrap();
        assert_eq!(q.sql, "SELECT id, name FROM users WHERE id = $1");
    }

    #[test]
    fn test_returns_without_colon_prefix() {
        let input = "-- @name Foo\n-- @returns many\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
    }

    #[test]
    fn test_batch_command() {
        let input = "-- @name Foo\n-- @returns :batch\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Batch);
    }

    #[test]
    fn test_grouped_command_with_group_by() {
        let input = "-- @name GetUsersWithOrders\n-- @returns :grouped\n-- @group_by users.id\nSELECT u.id, u.name FROM users u JOIN orders o ON o.user_id = u.id";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Grouped);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }

    #[test]
    fn test_grouped_command_without_group_by_errors() {
        let input = "-- @name Foo\n-- @returns :grouped\nSELECT 1";
        let err = parse(input).unwrap_err();
        assert_eq!(err.code, ErrorCode::InvalidAnnotation);
        assert!(err.message.contains("@group_by"));
    }

    #[test]
    fn test_group_by_without_grouped_is_ignored() {
        // `@group_by` is still recorded even when the command does not require it.
        let input = "-- @name Foo\n-- @returns :many\n-- @group_by users.id\nSELECT 1";
        let q = parse(input).unwrap();
        assert_eq!(q.command, QueryCommand::Many);
        assert_eq!(q.annotations.group_by, Some("users.id".to_string()));
    }
}