Skip to main content

sql_composer/parser/
command.rs

1//! Parser for `:count(...)` and `:union(...)` command macros.
2
3use std::path::PathBuf;
4
5use winnow::combinator::{alt, opt, separated, trace};
6use winnow::error::ParserError;
7use winnow::stream::{AsBStr, AsChar, Compare, Stream, StreamIsPartial};
8use winnow::token::{literal, take_while};
9use winnow::Parser;
10
11use crate::types::{Command, CommandKind};
12
13/// Parse optional whitespace within command parentheses.
14fn ws<'i, Input, Error>(input: &mut Input) -> Result<(), Error>
15where
16    Input: StreamIsPartial + Stream + Compare<&'i str>,
17    <Input as Stream>::Slice: AsBStr,
18    <Input as Stream>::Token: AsChar + Clone,
19    Error: ParserError<Input>,
20{
21    take_while(0.., |c: <Input as Stream>::Token| {
22        let ch = c.as_char();
23        ch == ' ' || ch == '\t' || ch == '\n' || ch == '\r'
24    })
25    .void()
26    .parse_next(input)
27}
28
29/// Parse the DISTINCT keyword.
30fn distinct<'i, Input, Error>(input: &mut Input) -> Result<bool, Error>
31where
32    Input: StreamIsPartial + Stream + Compare<&'i str>,
33    <Input as Stream>::Slice: AsBStr,
34    <Input as Stream>::Token: AsChar + Clone,
35    Error: ParserError<Input>,
36{
37    trace("distinct", move |input: &mut Input| {
38        literal("DISTINCT").parse_next(input)?;
39        ws(input)?;
40        Ok(true)
41    })
42    .parse_next(input)
43}
44
45/// Parse the ALL keyword.
46fn all_kw<'i, Input, Error>(input: &mut Input) -> Result<bool, Error>
47where
48    Input: StreamIsPartial + Stream + Compare<&'i str>,
49    <Input as Stream>::Slice: AsBStr,
50    <Input as Stream>::Token: AsChar + Clone,
51    Error: ParserError<Input>,
52{
53    trace("all", move |input: &mut Input| {
54        literal("ALL").parse_next(input)?;
55        ws(input)?;
56        Ok(true)
57    })
58    .parse_next(input)
59}
60
61/// Parse a column name: alphanumeric + underscore + dot.
62fn column_name<'i, Input, Error>(input: &mut Input) -> Result<String, Error>
63where
64    Input: StreamIsPartial + Stream + Compare<&'i str>,
65    <Input as Stream>::Slice: AsBStr,
66    <Input as Stream>::Token: AsChar + Clone,
67    Error: ParserError<Input>,
68{
69    let name = take_while(1.., |c: <Input as Stream>::Token| {
70        let ch = c.as_char();
71        ch.is_alphanumeric() || ch == '_' || ch == '.'
72    })
73    .parse_next(input)?;
74    let name = String::from_utf8_lossy(name.as_bstr()).to_string();
75    Ok(name)
76}
77
78/// Parse a comma separator with optional surrounding whitespace.
79fn comma_sep<'i, Input, Error>(input: &mut Input) -> Result<(), Error>
80where
81    Input: StreamIsPartial + Stream + Compare<&'i str>,
82    <Input as Stream>::Slice: AsBStr,
83    <Input as Stream>::Token: AsChar + Clone,
84    Error: ParserError<Input>,
85{
86    ws(input)?;
87    literal(",").parse_next(input)?;
88    ws(input)?;
89    Ok(())
90}
91
92/// Parse a column list followed by `OF`: `col1, col2, col3 OF`.
93fn columns_of<'i, Input, Error>(input: &mut Input) -> Result<Vec<String>, Error>
94where
95    Input: StreamIsPartial + Stream + Compare<&'i str>,
96    <Input as Stream>::Slice: AsBStr,
97    <Input as Stream>::Token: AsChar + Clone,
98    Error: ParserError<Input>,
99{
100    trace("columns_of", move |input: &mut Input| {
101        let cols: Vec<String> = separated(1.., column_name, comma_sep).parse_next(input)?;
102        ws(input)?;
103        literal("OF").parse_next(input)?;
104        ws(input)?;
105        Ok(cols)
106    })
107    .parse_next(input)
108}
109
110/// Parse a source path: one or more non-whitespace, non-`)`, non-`,` characters.
111fn source_path<'i, Input, Error>(input: &mut Input) -> Result<PathBuf, Error>
112where
113    Input: StreamIsPartial + Stream + Compare<&'i str>,
114    <Input as Stream>::Slice: AsBStr,
115    <Input as Stream>::Token: AsChar + Clone,
116    Error: ParserError<Input>,
117{
118    let path_str = take_while(1.., |c: <Input as Stream>::Token| {
119        let ch = c.as_char();
120        ch != ')' && ch != ',' && ch != ' ' && ch != '\t' && ch != '\n' && ch != '\r'
121    })
122    .parse_next(input)?;
123    let path_str = String::from_utf8_lossy(path_str.as_bstr()).to_string();
124    Ok(PathBuf::from(path_str))
125}
126
127/// Parse the command kind from the prefix keyword.
128///
129/// This parses `count(` or `union(` and returns the command kind.
130pub fn command_kind<'i, Input, Error>(input: &mut Input) -> Result<CommandKind, Error>
131where
132    Input: StreamIsPartial + Stream + Compare<&'i str>,
133    <Input as Stream>::Slice: AsBStr,
134    <Input as Stream>::Token: AsChar + Clone,
135    Error: ParserError<Input>,
136{
137    trace("command_kind", move |input: &mut Input| {
138        alt((
139            literal("count(").map(|_| CommandKind::Count),
140            literal("union(").map(|_| CommandKind::Union),
141        ))
142        .parse_next(input)
143    })
144    .parse_next(input)
145}
146
147/// Parse the body of a command after `count(` or `union(` has been consumed.
148///
149/// Grammar: `[DISTINCT] [ALL] [columns OF] source1[, source2, ...] )`
150pub fn command_body<'i, Input, Error>(
151    input: &mut Input,
152    kind: CommandKind,
153) -> Result<Command, Error>
154where
155    Input: StreamIsPartial + Stream + Compare<&'i str>,
156    <Input as Stream>::Slice: AsBStr,
157    <Input as Stream>::Token: AsChar + Clone,
158    Error: ParserError<Input>,
159{
160    trace("command_body", move |input: &mut Input| {
161        ws(input)?;
162        let is_distinct = opt(distinct).parse_next(input)?.unwrap_or(false);
163        let is_all = opt(all_kw).parse_next(input)?.unwrap_or(false);
164        let columns = opt(columns_of).parse_next(input)?;
165        let sources: Vec<PathBuf> = separated(1.., source_path, comma_sep).parse_next(input)?;
166        ws(input)?;
167        literal(")").parse_next(input)?;
168
169        Ok(Command {
170            kind,
171            distinct: is_distinct,
172            all: is_all,
173            columns,
174            sources,
175        })
176    })
177    .parse_next(input)
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use winnow::error::ContextError;

    type TestInput<'a> = &'a str;

    /// Run `command_kind` on `src`, panicking if the prefix does not parse.
    fn kind_of(src: &str) -> CommandKind {
        let mut input: TestInput = src;
        command_kind::<_, ContextError>
            .parse_next(&mut input)
            .expect("command prefix should parse")
    }

    /// Parse a complete `count(...)` / `union(...)` command, panicking on failure.
    fn parse_command(src: &str) -> Command {
        let mut input: TestInput = src;
        let kind = command_kind::<_, ContextError>
            .parse_next(&mut input)
            .expect("command prefix should parse");
        command_body::<_, ContextError>(&mut input, kind).expect("command body should parse")
    }

    #[test]
    fn test_command_kind_count() {
        assert_eq!(kind_of("count("), CommandKind::Count);
    }

    #[test]
    fn test_command_kind_union() {
        assert_eq!(kind_of("union("), CommandKind::Union);
    }

    #[test]
    fn test_command_simple_count() {
        let cmd = parse_command("count(templates/get_user.tql)");
        assert_eq!(cmd.kind, CommandKind::Count);
        assert!(!cmd.distinct);
        assert!(!cmd.all);
        assert_eq!(cmd.columns, None);
        assert_eq!(cmd.sources, vec![PathBuf::from("templates/get_user.tql")]);
    }

    #[test]
    fn test_command_union_multiple_sources() {
        let cmd = parse_command("union(a.tql, b.tql, c.tql)");
        assert_eq!(cmd.kind, CommandKind::Union);
        let expected: Vec<PathBuf> = ["a.tql", "b.tql", "c.tql"]
            .iter()
            .map(PathBuf::from)
            .collect();
        assert_eq!(cmd.sources, expected);
    }

    #[test]
    fn test_command_with_distinct() {
        let cmd = parse_command("union(DISTINCT a.tql, b.tql)");
        assert!(cmd.distinct);
        assert!(!cmd.all);
    }

    #[test]
    fn test_command_with_all() {
        let cmd = parse_command("union(ALL a.tql, b.tql)");
        assert!(!cmd.distinct);
        assert!(cmd.all);
    }

    #[test]
    fn test_command_with_columns() {
        let cmd = parse_command("count(id, name OF templates/get_user.tql)");
        assert_eq!(
            cmd.columns,
            Some(vec!["id".to_string(), "name".to_string()])
        );
        assert_eq!(cmd.sources, vec![PathBuf::from("templates/get_user.tql")]);
    }
}