Skip to main content

shape_ast/parser/queries/
joins.rs

1//! JOIN clause parser
2//!
3//! Parses JOIN clauses for ANALYZE queries:
4//! - `INNER JOIN source ON condition`
5//! - `LEFT JOIN source USING (col1, col2)`
6//! - `JOIN source WITHIN 100ms` (temporal join)
7
8use crate::ast::{JoinClause, JoinCondition, JoinSource, JoinType};
9use crate::data::Timeframe;
10use crate::error::{Result, ShapeError};
11use crate::parser::{Rule, expressions, pair_location};
12use pest::iterators::Pair;
13
14/// Parse a JOIN clause
15///
16/// Grammar: `join_type? "join" join_source join_condition?`
17pub fn parse_join_clause(pair: Pair<Rule>) -> Result<JoinClause> {
18    let pair_loc = pair_location(&pair);
19    let mut join_type = JoinType::Inner; // Default
20    let mut join_source = None;
21    let mut join_condition = JoinCondition::Natural; // Default for cross joins
22
23    for inner in pair.into_inner() {
24        match inner.as_rule() {
25            Rule::join_type => {
26                join_type = parse_join_type(inner)?;
27            }
28            Rule::join_source => {
29                join_source = Some(parse_join_source(inner)?);
30            }
31            Rule::join_condition => {
32                join_condition = parse_join_condition(inner)?;
33            }
34            _ => {}
35        }
36    }
37
38    let right = join_source.ok_or_else(|| ShapeError::ParseError {
39        message: "JOIN clause requires a source (table/symbol name or subquery)".to_string(),
40        location: Some(
41            pair_loc.with_hint("example: JOIN quotes ON trades.timestamp = quotes.timestamp"),
42        ),
43    })?;
44
45    // Cross joins don't require a condition
46    if matches!(join_type, JoinType::Cross) {
47        return Ok(JoinClause {
48            join_type,
49            right,
50            condition: JoinCondition::Natural,
51        });
52    }
53
54    Ok(JoinClause {
55        join_type,
56        right,
57        condition: join_condition,
58    })
59}
60
61/// Parse JOIN type
62///
63/// Grammar: `"inner" | "left" "outer"? | "right" "outer"? | "full" "outer"? | "cross"`
64fn parse_join_type(pair: Pair<Rule>) -> Result<JoinType> {
65    let text = pair.as_str().to_lowercase();
66
67    if text.starts_with("inner") {
68        Ok(JoinType::Inner)
69    } else if text.starts_with("left") {
70        Ok(JoinType::Left)
71    } else if text.starts_with("right") {
72        Ok(JoinType::Right)
73    } else if text.starts_with("full") {
74        Ok(JoinType::Full)
75    } else if text.starts_with("cross") {
76        Ok(JoinType::Cross)
77    } else {
78        // Default to inner
79        Ok(JoinType::Inner)
80    }
81}
82
83/// Parse JOIN source
84///
85/// Grammar: `ident ("as" ident)? | "(" inner_query ")" ("as" ident)?`
86pub fn parse_join_source(pair: Pair<Rule>) -> Result<JoinSource> {
87    let pair_loc = pair_location(&pair);
88    let mut inner_iter = pair.into_inner();
89
90    let first = inner_iter.next().ok_or_else(|| ShapeError::ParseError {
91        message: "expected join source".to_string(),
92        location: Some(pair_loc.clone()),
93    })?;
94
95    match first.as_rule() {
96        Rule::ident => {
97            // Named source (optionally with alias)
98            let name = first.as_str().to_string();
99            // For now, we just use the name (aliases would require extending JoinSource)
100            Ok(JoinSource::Named(name))
101        }
102        Rule::inner_query => {
103            // Subquery
104            let query = super::parse_inner_query(first)?;
105            Ok(JoinSource::Subquery(Box::new(query)))
106        }
107        _ => Err(ShapeError::ParseError {
108            message: format!("unexpected join source type: {:?}", first.as_rule()),
109            location: Some(pair_location(&first)),
110        }),
111    }
112}
113
114/// Parse JOIN condition
115///
116/// Grammar: `"on" expression | "using" "(" ident ("," ident)* ")" | "within" duration`
117fn parse_join_condition(pair: Pair<Rule>) -> Result<JoinCondition> {
118    let pair_loc = pair_location(&pair);
119    let mut inner_iter = pair.into_inner();
120
121    let first = inner_iter.next().ok_or_else(|| ShapeError::ParseError {
122        message: "expected join condition".to_string(),
123        location: Some(pair_loc.clone()),
124    })?;
125
126    match first.as_rule() {
127        Rule::expression => {
128            // ON condition
129            let expr = expressions::parse_expression(first)?;
130            Ok(JoinCondition::On(expr))
131        }
132        Rule::ident => {
133            // USING clause - first identifier already parsed
134            let mut columns = vec![first.as_str().to_string()];
135            for col in inner_iter {
136                if col.as_rule() == Rule::ident {
137                    columns.push(col.as_str().to_string());
138                }
139            }
140            Ok(JoinCondition::Using(columns))
141        }
142        Rule::duration => {
143            // WITHIN clause for temporal join
144            let timeframe = parse_duration_as_timeframe(first)?;
145            Ok(JoinCondition::Temporal {
146                left_time: "timestamp".to_string(),
147                right_time: "timestamp".to_string(),
148                within: timeframe,
149            })
150        }
151        _ => Err(ShapeError::ParseError {
152            message: format!("unexpected join condition type: {:?}", first.as_rule()),
153            location: Some(pair_location(&first)),
154        }),
155    }
156}
157
158/// Parse duration to Timeframe for temporal joins
159fn parse_duration_as_timeframe(pair: Pair<Rule>) -> Result<Timeframe> {
160    use crate::data::TimeframeUnit;
161
162    let text = pair.as_str().to_lowercase();
163    let pair_loc = pair_location(&pair);
164
165    // Parse duration like "100ms", "1s", "5m", etc.
166    let (num_str, unit_str) = extract_duration_parts(&text);
167
168    let value = num_str.parse::<u32>().map_err(|_| ShapeError::ParseError {
169        message: format!("invalid duration value: '{}'", num_str),
170        location: Some(pair_loc.clone()),
171    })?;
172
173    let unit = match unit_str {
174        "s" | "seconds" => TimeframeUnit::Second,
175        "m" | "minutes" => TimeframeUnit::Minute,
176        "h" | "hours" => TimeframeUnit::Hour,
177        "d" | "days" => TimeframeUnit::Day,
178        "w" | "weeks" => TimeframeUnit::Week,
179        "ms" => {
180            // Convert milliseconds to seconds with fractional handling
181            // For simplicity, treat ms as 1 second minimum
182            return Ok(Timeframe::new(1, TimeframeUnit::Second));
183        }
184        _ => {
185            return Err(ShapeError::ParseError {
186                message: format!("unknown duration unit: '{}'", unit_str),
187                location: Some(pair_loc.with_hint("valid units: s, m, h, d, w, ms")),
188            });
189        }
190    };
191
192    Ok(Timeframe::new(value, unit))
193}
194
/// Split a duration string into its leading numeric part and the remainder.
///
/// The numeric part is the longest prefix made of ASCII digits and `.`
/// characters; everything after it is returned as the unit part. Either part
/// may be empty (e.g. `"42"` → `("42", "")`, `"ms"` → `("", "ms")`).
fn extract_duration_parts(s: &str) -> (&str, &str) {
    let is_numeric = |c: char| c.is_ascii_digit() || c == '.';

    // Byte offset of the first non-numeric character, or the end of the string.
    let boundary = s
        .char_indices()
        .find(|&(_, c)| !is_numeric(c))
        .map_or(s.len(), |(offset, _)| offset);

    s.split_at(boundary)
}
202
#[cfg(test)]
mod tests {
    use super::*;
    use pest::Parser;

    /// Run the `join_clause` grammar rule over `input` and hand the first
    /// resulting pair to `parse_join_clause`. Grammar failures are reported as
    /// a `ShapeError::ParseError` without a location.
    fn parse_join(input: &str) -> Result<JoinClause> {
        let mut pairs = crate::parser::ShapeParser::parse(Rule::join_clause, input).map_err(
            |e| ShapeError::ParseError {
                message: format!("parse error: {}", e),
                location: None,
            },
        )?;
        let clause = pairs.next().unwrap();
        parse_join_clause(clause)
    }

    /// A bare `join ... on ...` is an inner join with an ON condition.
    #[test]
    fn test_inner_join_on() {
        let parsed = parse_join("join quotes on trades.id = quotes.id");
        assert!(parsed.is_ok());

        let JoinClause {
            join_type,
            condition,
            ..
        } = parsed.unwrap();
        assert!(matches!(join_type, JoinType::Inner));
        assert!(matches!(condition, JoinCondition::On(_)));
    }

    /// `left join ... using (a, b)` keeps the column names in source order.
    #[test]
    fn test_left_join_using() {
        let parsed = parse_join("left join orders using (symbol, timestamp)");
        assert!(parsed.is_ok());

        let join = parsed.unwrap();
        assert!(matches!(join.join_type, JoinType::Left));
        match &join.condition {
            JoinCondition::Using(cols) if cols.len() == 2 => {
                assert_eq!(cols[0], "symbol");
                assert_eq!(cols[1], "timestamp");
            }
            other => panic!("Expected Using condition with 2 columns, got {:?}", other),
        }
    }

    /// `join ... within <duration>` produces a temporal condition.
    #[test]
    fn test_temporal_join() {
        let parsed = parse_join("join executions within 100s");
        assert!(parsed.is_ok());

        let join = parsed.unwrap();
        assert!(matches!(join.condition, JoinCondition::Temporal { .. }));
    }
}