shape_ast/parser/queries/
joins.rs1use crate::ast::{JoinClause, JoinCondition, JoinSource, JoinType};
9use crate::data::Timeframe;
10use crate::error::{Result, ShapeError};
11use crate::parser::{Rule, expressions, pair_location};
12use pest::iterators::Pair;
13
14pub fn parse_join_clause(pair: Pair<Rule>) -> Result<JoinClause> {
18 let pair_loc = pair_location(&pair);
19 let mut join_type = JoinType::Inner; let mut join_source = None;
21 let mut join_condition = JoinCondition::Natural; for inner in pair.into_inner() {
24 match inner.as_rule() {
25 Rule::join_type => {
26 join_type = parse_join_type(inner)?;
27 }
28 Rule::join_source => {
29 join_source = Some(parse_join_source(inner)?);
30 }
31 Rule::join_condition => {
32 join_condition = parse_join_condition(inner)?;
33 }
34 _ => {}
35 }
36 }
37
38 let right = join_source.ok_or_else(|| ShapeError::ParseError {
39 message: "JOIN clause requires a source (table/symbol name or subquery)".to_string(),
40 location: Some(
41 pair_loc.with_hint("example: JOIN quotes ON trades.timestamp = quotes.timestamp"),
42 ),
43 })?;
44
45 if matches!(join_type, JoinType::Cross) {
47 return Ok(JoinClause {
48 join_type,
49 right,
50 condition: JoinCondition::Natural,
51 });
52 }
53
54 Ok(JoinClause {
55 join_type,
56 right,
57 condition: join_condition,
58 })
59}
60
61fn parse_join_type(pair: Pair<Rule>) -> Result<JoinType> {
65 let text = pair.as_str().to_lowercase();
66
67 if text.starts_with("inner") {
68 Ok(JoinType::Inner)
69 } else if text.starts_with("left") {
70 Ok(JoinType::Left)
71 } else if text.starts_with("right") {
72 Ok(JoinType::Right)
73 } else if text.starts_with("full") {
74 Ok(JoinType::Full)
75 } else if text.starts_with("cross") {
76 Ok(JoinType::Cross)
77 } else {
78 Ok(JoinType::Inner)
80 }
81}
82
83pub fn parse_join_source(pair: Pair<Rule>) -> Result<JoinSource> {
87 let pair_loc = pair_location(&pair);
88 let mut inner_iter = pair.into_inner();
89
90 let first = inner_iter.next().ok_or_else(|| ShapeError::ParseError {
91 message: "expected join source".to_string(),
92 location: Some(pair_loc.clone()),
93 })?;
94
95 match first.as_rule() {
96 Rule::ident => {
97 let name = first.as_str().to_string();
99 Ok(JoinSource::Named(name))
101 }
102 Rule::inner_query => {
103 let query = super::parse_inner_query(first)?;
105 Ok(JoinSource::Subquery(Box::new(query)))
106 }
107 _ => Err(ShapeError::ParseError {
108 message: format!("unexpected join source type: {:?}", first.as_rule()),
109 location: Some(pair_location(&first)),
110 }),
111 }
112}
113
114fn parse_join_condition(pair: Pair<Rule>) -> Result<JoinCondition> {
118 let pair_loc = pair_location(&pair);
119 let mut inner_iter = pair.into_inner();
120
121 let first = inner_iter.next().ok_or_else(|| ShapeError::ParseError {
122 message: "expected join condition".to_string(),
123 location: Some(pair_loc.clone()),
124 })?;
125
126 match first.as_rule() {
127 Rule::expression => {
128 let expr = expressions::parse_expression(first)?;
130 Ok(JoinCondition::On(expr))
131 }
132 Rule::ident => {
133 let mut columns = vec![first.as_str().to_string()];
135 for col in inner_iter {
136 if col.as_rule() == Rule::ident {
137 columns.push(col.as_str().to_string());
138 }
139 }
140 Ok(JoinCondition::Using(columns))
141 }
142 Rule::duration => {
143 let timeframe = parse_duration_as_timeframe(first)?;
145 Ok(JoinCondition::Temporal {
146 left_time: "timestamp".to_string(),
147 right_time: "timestamp".to_string(),
148 within: timeframe,
149 })
150 }
151 _ => Err(ShapeError::ParseError {
152 message: format!("unexpected join condition type: {:?}", first.as_rule()),
153 location: Some(pair_location(&first)),
154 }),
155 }
156}
157
158fn parse_duration_as_timeframe(pair: Pair<Rule>) -> Result<Timeframe> {
160 use crate::data::TimeframeUnit;
161
162 let text = pair.as_str().to_lowercase();
163 let pair_loc = pair_location(&pair);
164
165 let (num_str, unit_str) = extract_duration_parts(&text);
167
168 let value = num_str.parse::<u32>().map_err(|_| ShapeError::ParseError {
169 message: format!("invalid duration value: '{}'", num_str),
170 location: Some(pair_loc.clone()),
171 })?;
172
173 let unit = match unit_str {
174 "s" | "seconds" => TimeframeUnit::Second,
175 "m" | "minutes" => TimeframeUnit::Minute,
176 "h" | "hours" => TimeframeUnit::Hour,
177 "d" | "days" => TimeframeUnit::Day,
178 "w" | "weeks" => TimeframeUnit::Week,
179 "ms" => {
180 return Ok(Timeframe::new(1, TimeframeUnit::Second));
183 }
184 _ => {
185 return Err(ShapeError::ParseError {
186 message: format!("unknown duration unit: '{}'", unit_str),
187 location: Some(pair_loc.with_hint("valid units: s, m, h, d, w, ms")),
188 });
189 }
190 };
191
192 Ok(Timeframe::new(value, unit))
193}
194
/// Splits a duration string into its leading numeric part and the remainder.
///
/// The numeric part is the longest prefix made only of ASCII digits and `.`
/// characters; everything from the first other character onward is returned
/// as the second element. Either part may be empty.
fn extract_duration_parts(s: &str) -> (&str, &str) {
    let boundary = s
        .char_indices()
        .find(|&(_, ch)| !(ch.is_ascii_digit() || ch == '.'))
        .map_or(s.len(), |(offset, _)| offset);
    s.split_at(boundary)
}
202
#[cfg(test)]
mod tests {
    use super::*;
    use pest::Parser;

    /// Runs the grammar's `join_clause` rule over `input` and feeds the first
    /// resulting pair to `parse_join_clause`. Grammar failures are wrapped in
    /// a location-less `ShapeError::ParseError`.
    fn parse_join(input: &str) -> Result<JoinClause> {
        let mut pairs = match crate::parser::ShapeParser::parse(Rule::join_clause, input) {
            Ok(pairs) => pairs,
            Err(e) => {
                return Err(ShapeError::ParseError {
                    message: format!("parse error: {}", e),
                    location: None,
                });
            }
        };
        parse_join_clause(pairs.next().unwrap())
    }

    /// A bare `join … on …` is an inner join with an ON condition.
    #[test]
    fn test_inner_join_on() {
        let parsed = parse_join("join quotes on trades.id = quotes.id");
        assert!(parsed.is_ok());
        let clause = parsed.unwrap();
        assert!(matches!(clause.join_type, JoinType::Inner));
        assert!(matches!(clause.condition, JoinCondition::On(_)));
    }

    /// `left join … using (a, b)` keeps the column names in order.
    #[test]
    fn test_left_join_using() {
        let parsed = parse_join("left join orders using (symbol, timestamp)");
        assert!(parsed.is_ok());
        let clause = parsed.unwrap();
        assert!(matches!(clause.join_type, JoinType::Left));
        match &clause.condition {
            JoinCondition::Using(cols) if cols.len() == 2 => {
                assert_eq!(cols[0], "symbol");
                assert_eq!(cols[1], "timestamp");
            }
            other => panic!("Expected Using condition with 2 columns, got {:?}", other),
        }
    }

    /// `join … within <duration>` produces a temporal condition.
    #[test]
    fn test_temporal_join() {
        let parsed = parse_join("join executions within 100s");
        assert!(parsed.is_ok());
        let clause = parsed.unwrap();
        assert!(matches!(clause.condition, JoinCondition::Temporal { .. }));
    }
}