1use winnow::combinator::{alt, repeat, trace};
11use winnow::error::ParserError;
12use winnow::stream::{AsBStr, AsChar, Compare, Stream, StreamIsPartial};
13use winnow::token::{any, literal};
14use winnow::Parser;
15
16use crate::types::Element;
17
18use super::bind::bind;
19use super::command::{command_body, command_kind};
20use super::compose::compose;
21
22fn macro_invocation<'i, Input, Error>(input: &mut Input) -> Result<Element, Error>
26where
27 Input: StreamIsPartial + Stream + Compare<&'i str>,
28 <Input as Stream>::Slice: AsBStr,
29 <Input as Stream>::Token: AsChar + Clone,
30 Error: ParserError<Input>,
31{
32 trace("macro_invocation", move |input: &mut Input| {
33 literal(":").parse_next(input)?;
34
35 alt((
36 literal("bind(").flat_map(|_| bind).map(Element::Bind),
37 literal("compose(")
38 .flat_map(|_| compose)
39 .map(Element::Compose),
40 |input: &mut Input| {
41 let kind = command_kind(input)?;
42 let cmd = command_body(input, kind)?;
43 Ok(Element::Command(cmd))
44 },
45 ))
46 .parse_next(input)
47 })
48 .parse_next(input)
49}
50
51fn sql_literal<'i, Input, Error>(input: &mut Input) -> Result<Element, Error>
57where
58 Input: StreamIsPartial + Stream + Compare<&'i str>,
59 <Input as Stream>::Slice: AsBStr,
60 <Input as Stream>::Token: AsChar + Clone,
61 Error: ParserError<Input>,
62{
63 trace("sql_literal", move |input: &mut Input| {
64 let mut sql = String::new();
65
66 loop {
67 let checkpoint = input.checkpoint();
69 if literal::<_, _, Error>(":").parse_next(input).is_ok() {
70 let is_macro = alt((
72 literal::<_, Input, Error>("bind(").void(),
73 literal::<_, Input, Error>("compose(").void(),
74 literal::<_, Input, Error>("count(").void(),
75 literal::<_, Input, Error>("union(").void(),
76 ))
77 .parse_next(input)
78 .is_ok();
79
80 input.reset(&checkpoint);
82
83 if is_macro {
84 break;
85 }
86 } else {
87 input.reset(&checkpoint);
88 }
89
90 match any::<_, Error>.parse_next(input) {
92 Ok(c) => {
93 let ch = c.as_char();
94 if ch == '#' {
95 loop {
97 match any::<_, Error>.parse_next(input) {
98 Ok(c) if c.clone().as_char() == '\n' => break,
99 Ok(_) => continue,
100 Err(_) => break, }
102 }
103 } else {
104 sql.push(ch);
105 }
106 }
107 Err(_) => break, }
109 }
110
111 if sql.is_empty() {
112 return Err(ParserError::from_input(input));
113 }
114
115 Ok(Element::Sql(sql))
116 })
117 .parse_next(input)
118}
119
120fn element<'i, Input, Error>(input: &mut Input) -> Result<Element, Error>
122where
123 Input: StreamIsPartial + Stream + Compare<&'i str>,
124 <Input as Stream>::Slice: AsBStr,
125 <Input as Stream>::Token: AsChar + Clone,
126 Error: ParserError<Input>,
127{
128 trace("element", move |input: &mut Input| {
129 alt((macro_invocation, sql_literal)).parse_next(input)
130 })
131 .parse_next(input)
132}
133
134pub fn template<'i, Input, Error>(input: &mut Input) -> Result<Vec<Element>, Error>
138where
139 Input: StreamIsPartial + Stream + Compare<&'i str>,
140 <Input as Stream>::Slice: AsBStr,
141 <Input as Stream>::Token: AsChar + Clone,
142 Error: ParserError<Input>,
143{
144 trace("template", move |input: &mut Input| {
145 let elements: Vec<Element> = repeat(0.., element).parse_next(input)?;
146 Ok(elements)
147 })
148 .parse_next(input)
149}
150
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Binding, CommandKind, ComposeRef};
    use std::path::PathBuf;
    use winnow::error::ContextError;

    type TestInput<'a> = &'a str;

    /// Runs `template` over `source` with `ContextError`, panicking on failure.
    fn parse_template(source: &str) -> Vec<Element> {
        let mut input: TestInput = source;
        template::<_, ContextError>.parse_next(&mut input).unwrap()
    }

    /// Builds a literal-SQL element.
    fn sql(text: &str) -> Element {
        Element::Sql(text.into())
    }

    /// Builds a bind element with no value bounds that is not nullable.
    fn plain_bind(name: &str) -> Element {
        Element::Bind(Binding {
            name: name.into(),
            min_values: None,
            max_values: None,
            nullable: false,
        })
    }

    #[test]
    fn test_plain_sql() {
        assert_eq!(
            parse_template("SELECT id, name FROM users"),
            vec![sql("SELECT id, name FROM users")]
        );
    }

    #[test]
    fn test_sql_with_bind() {
        assert_eq!(
            parse_template("SELECT * FROM users WHERE id = :bind(user_id)"),
            vec![sql("SELECT * FROM users WHERE id = "), plain_bind("user_id")]
        );
    }

    #[test]
    fn test_sql_with_compose() {
        let expected = vec![
            sql("SELECT COUNT(*) FROM (\n "),
            Element::Compose(ComposeRef {
                path: PathBuf::from("templates/get_user.tql"),
            }),
            sql("\n)"),
        ];
        assert_eq!(
            parse_template("SELECT COUNT(*) FROM (\n :compose(templates/get_user.tql)\n)"),
            expected
        );
    }

    #[test]
    fn test_multiple_binds() {
        assert_eq!(
            parse_template("WHERE id = :bind(user_id) AND active = :bind(active)"),
            vec![
                sql("WHERE id = "),
                plain_bind("user_id"),
                sql(" AND active = "),
                plain_bind("active"),
            ]
        );
    }

    #[test]
    fn test_colon_not_a_macro() {
        assert_eq!(
            parse_template("SELECT '10:30' FROM t"),
            vec![sql("SELECT '10:30' FROM t")]
        );
    }

    #[test]
    fn test_command_in_template() {
        let result = parse_template(":count(templates/get_user.tql)");
        assert_eq!(result.len(), 1);
        if let Element::Command(cmd) = &result[0] {
            assert_eq!(cmd.kind, CommandKind::Count);
            assert_eq!(cmd.sources, vec![PathBuf::from("templates/get_user.tql")]);
        } else {
            panic!("expected Command, got {:?}", result[0]);
        }
    }

    #[test]
    fn test_full_template() {
        let source = "SELECT id, name, email\nFROM users\nWHERE id = :bind(user_id)\n AND active = :bind(active);";
        assert_eq!(
            parse_template(source),
            vec![
                sql("SELECT id, name, email\nFROM users\nWHERE id = "),
                plain_bind("user_id"),
                sql("\n AND active = "),
                plain_bind("active"),
                sql(";"),
            ]
        );
    }

    #[test]
    fn test_semicolon_after_bind() {
        assert_eq!(
            parse_template("WHERE id = :bind(user_id);"),
            vec![sql("WHERE id = "), plain_bind("user_id"), sql(";")]
        );
    }

    #[test]
    fn test_empty_input() {
        assert!(parse_template("").is_empty());
    }

    #[test]
    fn test_comment_standalone_line() {
        assert_eq!(parse_template("# comment\nSELECT 1;"), vec![sql("SELECT 1;")]);
    }

    #[test]
    fn test_comment_inline() {
        assert_eq!(
            parse_template("SELECT 1; # comment\nSELECT 2;"),
            vec![sql("SELECT 1; SELECT 2;")]
        );
    }

    #[test]
    fn test_comment_with_macro_text() {
        assert_eq!(parse_template("# :bind(x)\nSELECT 1;"), vec![sql("SELECT 1;")]);
    }

    #[test]
    fn test_comment_before_macro() {
        assert_eq!(
            parse_template("# get user\nSELECT * FROM users WHERE id = :bind(id);"),
            vec![
                sql("SELECT * FROM users WHERE id = "),
                plain_bind("id"),
                sql(";"),
            ]
        );
    }

    #[test]
    fn test_only_comments() {
        assert!(parse_template("# just a comment").is_empty());
    }

    #[test]
    fn test_multiple_comment_lines() {
        assert_eq!(
            parse_template("# line 1\n# line 2\nSELECT 1;"),
            vec![sql("SELECT 1;")]
        );
    }

    #[test]
    fn test_comment_at_eof_no_newline() {
        assert_eq!(
            parse_template("SELECT 1;\n# trailing"),
            vec![sql("SELECT 1;\n")]
        );
    }
}