Skip to main content

sql_composer/parser/
template.rs

//! Top-level template parser that dispatches between macros and literal SQL.
//!
//! The key insight of this parser is that SQL is treated as opaque literal text.
//! Only the `:bind(...)`, `:compose(...)`, `:count(...)`, and `:union(...)`
//! macros are parsed; everything else, apart from template comments, passes
//! through unchanged.
//!
//! A `#` starts a template comment that runs to the end of its line. The comment
//! and its terminating newline are stripped during parsing and never appear in
//! composed SQL output. Note that `#` is recognized anywhere in literal SQL,
//! including inside quoted SQL string literals.
9
10use winnow::combinator::{alt, repeat, trace};
11use winnow::error::ParserError;
12use winnow::stream::{AsBStr, AsChar, Compare, Stream, StreamIsPartial};
13use winnow::token::{any, literal};
14use winnow::Parser;
15
16use crate::types::Element;
17
18use super::bind::bind;
19use super::command::{command_body, command_kind};
20use super::compose::compose;
21
22/// Parse a single macro invocation after the `:` prefix.
23///
24/// Tries `bind(`, `compose(`, `count(`, or `union(` in order.
25fn macro_invocation<'i, Input, Error>(input: &mut Input) -> Result<Element, Error>
26where
27    Input: StreamIsPartial + Stream + Compare<&'i str>,
28    <Input as Stream>::Slice: AsBStr,
29    <Input as Stream>::Token: AsChar + Clone,
30    Error: ParserError<Input>,
31{
32    trace("macro_invocation", move |input: &mut Input| {
33        literal(":").parse_next(input)?;
34
35        alt((
36            literal("bind(").flat_map(|_| bind).map(Element::Bind),
37            literal("compose(")
38                .flat_map(|_| compose)
39                .map(Element::Compose),
40            |input: &mut Input| {
41                let kind = command_kind(input)?;
42                let cmd = command_body(input, kind)?;
43                Ok(Element::Command(cmd))
44            },
45        ))
46        .parse_next(input)
47    })
48    .parse_next(input)
49}
50
51/// Parse literal SQL text: everything up to the next `:` that starts a macro,
52/// or to the end of input.
53///
54/// Accumulates characters one at a time, stopping when we encounter a `:`
55/// followed by a known macro name and `(`.
56fn sql_literal<'i, Input, Error>(input: &mut Input) -> Result<Element, Error>
57where
58    Input: StreamIsPartial + Stream + Compare<&'i str>,
59    <Input as Stream>::Slice: AsBStr,
60    <Input as Stream>::Token: AsChar + Clone,
61    Error: ParserError<Input>,
62{
63    trace("sql_literal", move |input: &mut Input| {
64        let mut sql = String::new();
65
66        loop {
67            // Check if we're at a macro start
68            let checkpoint = input.checkpoint();
69            if literal::<_, _, Error>(":").parse_next(input).is_ok() {
70                // Check if this is followed by a known macro name + "("
71                let is_macro = alt((
72                    literal::<_, Input, Error>("bind(").void(),
73                    literal::<_, Input, Error>("compose(").void(),
74                    literal::<_, Input, Error>("count(").void(),
75                    literal::<_, Input, Error>("union(").void(),
76                ))
77                .parse_next(input)
78                .is_ok();
79
80                // Reset to before the ":"
81                input.reset(&checkpoint);
82
83                if is_macro {
84                    break;
85                }
86            } else {
87                input.reset(&checkpoint);
88            }
89
90            // Try to consume one character
91            match any::<_, Error>.parse_next(input) {
92                Ok(c) => {
93                    let ch = c.as_char();
94                    if ch == '#' {
95                        // Comment: skip to end of line (or EOF)
96                        loop {
97                            match any::<_, Error>.parse_next(input) {
98                                Ok(c) if c.clone().as_char() == '\n' => break,
99                                Ok(_) => continue,
100                                Err(_) => break, // EOF
101                            }
102                        }
103                    } else {
104                        sql.push(ch);
105                    }
106                }
107                Err(_) => break, // EOF
108            }
109        }
110
111        if sql.is_empty() {
112            return Err(ParserError::from_input(input));
113        }
114
115        Ok(Element::Sql(sql))
116    })
117    .parse_next(input)
118}
119
120/// Parse a single template element: either a macro invocation or literal SQL.
121fn element<'i, Input, Error>(input: &mut Input) -> Result<Element, Error>
122where
123    Input: StreamIsPartial + Stream + Compare<&'i str>,
124    <Input as Stream>::Slice: AsBStr,
125    <Input as Stream>::Token: AsChar + Clone,
126    Error: ParserError<Input>,
127{
128    trace("element", move |input: &mut Input| {
129        alt((macro_invocation, sql_literal)).parse_next(input)
130    })
131    .parse_next(input)
132}
133
134/// Parse a complete template into a sequence of elements.
135///
136/// This is the top-level parser entry point for template content.
137pub fn template<'i, Input, Error>(input: &mut Input) -> Result<Vec<Element>, Error>
138where
139    Input: StreamIsPartial + Stream + Compare<&'i str>,
140    <Input as Stream>::Slice: AsBStr,
141    <Input as Stream>::Token: AsChar + Clone,
142    Error: ParserError<Input>,
143{
144    trace("template", move |input: &mut Input| {
145        let elements: Vec<Element> = repeat(0.., element).parse_next(input)?;
146        Ok(elements)
147    })
148    .parse_next(input)
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Binding, CommandKind, ComposeRef};
    use std::path::PathBuf;
    use winnow::error::ContextError;

    /// Run the template parser over `src`, panicking if it reports an error.
    fn parse(src: &str) -> Vec<Element> {
        let mut input = src;
        template::<_, ContextError>.parse_next(&mut input).unwrap()
    }

    /// Build a literal SQL element.
    fn sql(text: &str) -> Element {
        Element::Sql(text.into())
    }

    /// Build a plain `:bind(name)` element: no value limits, not nullable.
    fn simple_bind(name: &str) -> Element {
        Element::Bind(Binding {
            name: name.into(),
            min_values: None,
            max_values: None,
            nullable: false,
        })
    }

    #[test]
    fn test_plain_sql() {
        assert_eq!(
            parse("SELECT id, name FROM users"),
            vec![sql("SELECT id, name FROM users")]
        );
    }

    #[test]
    fn test_sql_with_bind() {
        assert_eq!(
            parse("SELECT * FROM users WHERE id = :bind(user_id)"),
            vec![sql("SELECT * FROM users WHERE id = "), simple_bind("user_id")]
        );
    }

    #[test]
    fn test_sql_with_compose() {
        assert_eq!(
            parse("SELECT COUNT(*) FROM (\n  :compose(templates/get_user.tql)\n)"),
            vec![
                sql("SELECT COUNT(*) FROM (\n  "),
                Element::Compose(ComposeRef {
                    path: PathBuf::from("templates/get_user.tql"),
                }),
                sql("\n)"),
            ]
        );
    }

    #[test]
    fn test_multiple_binds() {
        assert_eq!(
            parse("WHERE id = :bind(user_id) AND active = :bind(active)"),
            vec![
                sql("WHERE id = "),
                simple_bind("user_id"),
                sql(" AND active = "),
                simple_bind("active"),
            ]
        );
    }

    #[test]
    fn test_colon_not_a_macro() {
        assert_eq!(
            parse("SELECT '10:30' FROM t"),
            vec![sql("SELECT '10:30' FROM t")]
        );
    }

    #[test]
    fn test_command_in_template() {
        let result = parse(":count(templates/get_user.tql)");
        assert_eq!(result.len(), 1);
        if let Element::Command(cmd) = &result[0] {
            assert_eq!(cmd.kind, CommandKind::Count);
            assert_eq!(cmd.sources, vec![PathBuf::from("templates/get_user.tql")]);
        } else {
            panic!("expected Command, got {:?}", result[0]);
        }
    }

    #[test]
    fn test_full_template() {
        assert_eq!(
            parse(
                "SELECT id, name, email\nFROM users\nWHERE id = :bind(user_id)\n  AND active = :bind(active);"
            ),
            vec![
                sql("SELECT id, name, email\nFROM users\nWHERE id = "),
                simple_bind("user_id"),
                sql("\n  AND active = "),
                simple_bind("active"),
                sql(";"),
            ]
        );
    }

    #[test]
    fn test_semicolon_after_bind() {
        assert_eq!(
            parse("WHERE id = :bind(user_id);"),
            vec![sql("WHERE id = "), simple_bind("user_id"), sql(";")]
        );
    }

    #[test]
    fn test_empty_input() {
        assert!(parse("").is_empty());
    }

    #[test]
    fn test_comment_standalone_line() {
        assert_eq!(parse("# comment\nSELECT 1;"), vec![sql("SELECT 1;")]);
    }

    #[test]
    fn test_comment_inline() {
        assert_eq!(
            parse("SELECT 1; # comment\nSELECT 2;"),
            vec![sql("SELECT 1; SELECT 2;")]
        );
    }

    #[test]
    fn test_comment_with_macro_text() {
        assert_eq!(parse("# :bind(x)\nSELECT 1;"), vec![sql("SELECT 1;")]);
    }

    #[test]
    fn test_comment_before_macro() {
        assert_eq!(
            parse("# get user\nSELECT * FROM users WHERE id = :bind(id);"),
            vec![
                sql("SELECT * FROM users WHERE id = "),
                simple_bind("id"),
                sql(";"),
            ]
        );
    }

    #[test]
    fn test_only_comments() {
        assert!(parse("# just a comment").is_empty());
    }

    #[test]
    fn test_multiple_comment_lines() {
        assert_eq!(parse("# line 1\n# line 2\nSELECT 1;"), vec![sql("SELECT 1;")]);
    }

    #[test]
    fn test_comment_at_eof_no_newline() {
        assert_eq!(parse("SELECT 1;\n# trailing"), vec![sql("SELECT 1;\n")]);
    }
}