tensorlogic_cli/
parser.rs

//! Enhanced expression parser for TensorLogic CLI
//!
//! Supports:
//! - Basic predicates: pred(x, y)
//! - Logical operators: AND, OR, NOT, IMPLIES
//! - Quantifiers: EXISTS, FORALL
//! - Arithmetic: +, -, *, /
//! - Comparisons: =, <, >, <=, >=, !=
//! - Conditionals: IF-THEN-ELSE
//! - Parentheses for grouping
//!
//! Several operators also accept a Unicode spelling (e.g. `∧`, `¬`, `→`,
//! `∃`, `∀`, `≤`, `≥`, `≠`, `×`, `÷`).
11
12use anyhow::{bail, Result};
13use tensorlogic_ir::{TLExpr, Term};
14
15/// Parse expression string with enhanced syntax support
16pub fn parse_expression(input: &str) -> Result<TLExpr> {
17    let input = input.trim();
18
19    if input.is_empty() {
20        bail!("Empty expression");
21    }
22
23    // Handle IF-THEN-ELSE first to avoid operator splitting
24    if input.starts_with("IF ") || input.starts_with("if ") {
25        return parse_conditional(input);
26    }
27
28    // Parse with operator precedence
29    parse_implication(input)
30}
31
32// Operator precedence (lowest to highest):
33// 1. IMPLIES (→)
34// 2. OR (|, ||)
35// 3. AND (&, &&)
36// 4. Comparisons (=, <, >, <=, >=, !=)
37// 5. Arithmetic (+, -)
38// 6. Arithmetic (*, /)
39// 7. NOT (~, !)
40// 8. Quantifiers (EXISTS, FORALL)
41// 9. Conditionals (IF-THEN-ELSE)
42// 10. Predicates
43
44fn parse_implication(input: &str) -> Result<TLExpr> {
45    if let Some(pos) = find_operator(input, &["->", "IMPLIES", "=>", "→"]) {
46        let left = parse_or(input[..pos].trim())?;
47        let right = parse_implication(input[pos + operator_len(&input[pos..])..].trim())?;
48        return Ok(TLExpr::imply(left, right));
49    }
50    parse_or(input)
51}
52
53fn parse_or(input: &str) -> Result<TLExpr> {
54    if let Some(pos) = find_operator(input, &[" OR ", " | ", "||"]) {
55        let left = parse_and(input[..pos].trim())?;
56        let right = parse_or(input[pos + operator_len(&input[pos..])..].trim())?;
57        return Ok(TLExpr::Or(Box::new(left), Box::new(right)));
58    }
59    parse_and(input)
60}
61
62fn parse_and(input: &str) -> Result<TLExpr> {
63    if let Some(pos) = find_operator(input, &[" AND ", " & ", "&&", "∧"]) {
64        let left = parse_comparison(input[..pos].trim())?;
65        let right = parse_and(input[pos + operator_len(&input[pos..])..].trim())?;
66        return Ok(TLExpr::And(Box::new(left), Box::new(right)));
67    }
68    parse_comparison(input)
69}
70
71fn parse_comparison(input: &str) -> Result<TLExpr> {
72    // Check for comparison operators
73    if let Some(pos) = find_operator(input, &[" = ", " == "]) {
74        let left = parse_additive(input[..pos].trim())?;
75        let right = parse_additive(input[pos + operator_len(&input[pos..])..].trim())?;
76        return Ok(TLExpr::Eq(Box::new(left), Box::new(right)));
77    }
78
79    if let Some(pos) = find_operator(input, &[" <= ", " ≤ "]) {
80        let left = parse_additive(input[..pos].trim())?;
81        let right = parse_additive(input[pos + operator_len(&input[pos..])..].trim())?;
82        return Ok(TLExpr::Lte(Box::new(left), Box::new(right)));
83    }
84
85    if let Some(pos) = find_operator(input, &[" >= ", " ≥ "]) {
86        let left = parse_additive(input[..pos].trim())?;
87        let right = parse_additive(input[pos + operator_len(&input[pos..])..].trim())?;
88        return Ok(TLExpr::Gte(Box::new(left), Box::new(right)));
89    }
90
91    if let Some(pos) = find_operator(input, &[" < "]) {
92        let left = parse_additive(input[..pos].trim())?;
93        let right = parse_additive(input[pos + operator_len(&input[pos..])..].trim())?;
94        return Ok(TLExpr::Lt(Box::new(left), Box::new(right)));
95    }
96
97    if let Some(pos) = find_operator(input, &[" > "]) {
98        let left = parse_additive(input[..pos].trim())?;
99        let right = parse_additive(input[pos + operator_len(&input[pos..])..].trim())?;
100        return Ok(TLExpr::Gt(Box::new(left), Box::new(right)));
101    }
102
103    if let Some(pos) = find_operator(input, &[" != ", " ≠ "]) {
104        let left = parse_additive(input[..pos].trim())?;
105        let right = parse_additive(input[pos + operator_len(&input[pos..])..].trim())?;
106        let eq = TLExpr::Eq(Box::new(left), Box::new(right));
107        return Ok(TLExpr::Not(Box::new(eq)));
108    }
109
110    parse_additive(input)
111}
112
113fn parse_additive(input: &str) -> Result<TLExpr> {
114    if let Some(pos) = find_operator(input, &[" + "]) {
115        let left = parse_multiplicative(input[..pos].trim())?;
116        let right = parse_additive(input[pos + 3..].trim())?;
117        return Ok(TLExpr::Add(Box::new(left), Box::new(right)));
118    }
119
120    if let Some(pos) = find_operator(input, &[" - "]) {
121        let left = parse_multiplicative(input[..pos].trim())?;
122        let right = parse_additive(input[pos + 3..].trim())?;
123        return Ok(TLExpr::Sub(Box::new(left), Box::new(right)));
124    }
125
126    parse_multiplicative(input)
127}
128
129fn parse_multiplicative(input: &str) -> Result<TLExpr> {
130    if let Some(pos) = find_operator(input, &[" * ", " × "]) {
131        let left = parse_unary(input[..pos].trim())?;
132        let right = parse_multiplicative(input[pos + operator_len(&input[pos..])..].trim())?;
133        return Ok(TLExpr::Mul(Box::new(left), Box::new(right)));
134    }
135
136    if let Some(pos) = find_operator(input, &[" / ", " ÷ "]) {
137        let left = parse_unary(input[..pos].trim())?;
138        let right = parse_multiplicative(input[pos + operator_len(&input[pos..])..].trim())?;
139        return Ok(TLExpr::Div(Box::new(left), Box::new(right)));
140    }
141
142    parse_unary(input)
143}
144
145fn parse_unary(input: &str) -> Result<TLExpr> {
146    let input = input.trim();
147
148    // Handle NOT
149    for prefix in &["NOT ", "not ", "~", "!", "¬"] {
150        if let Some(rest) = input.strip_prefix(prefix) {
151            let body = parse_unary(rest.trim())?;
152            return Ok(TLExpr::Not(Box::new(body)));
153        }
154    }
155
156    parse_primary(input)
157}
158
159fn parse_primary(input: &str) -> Result<TLExpr> {
160    let input = input.trim();
161
162    // Handle quantifiers
163    if let Some(rest) = input
164        .strip_prefix("EXISTS ")
165        .or_else(|| input.strip_prefix("exists "))
166        .or_else(|| input.strip_prefix("∃ "))
167    {
168        return parse_quantifier(rest, true);
169    }
170
171    if let Some(rest) = input
172        .strip_prefix("FORALL ")
173        .or_else(|| input.strip_prefix("forall "))
174        .or_else(|| input.strip_prefix("∀ "))
175    {
176        return parse_quantifier(rest, false);
177    }
178
179    // Handle parentheses
180    if input.starts_with('(') && input.ends_with(')') {
181        let inner = &input[1..input.len() - 1];
182        if is_balanced(inner) {
183            return parse_expression(inner);
184        }
185    }
186
187    // Handle numeric constants
188    if let Ok(value) = input.parse::<f64>() {
189        return Ok(TLExpr::Constant(value));
190    }
191
192    // Handle predicates
193    if let Some(paren_pos) = input.find('(') {
194        if input.ends_with(')') {
195            let name = input[..paren_pos].trim();
196            let args_str = &input[paren_pos + 1..input.len() - 1];
197
198            let args: Vec<Term> = if args_str.trim().is_empty() {
199                vec![]
200            } else {
201                args_str
202                    .split(',')
203                    .map(|a| parse_term(a.trim()))
204                    .collect::<Result<Vec<_>>>()?
205            };
206
207            return Ok(TLExpr::pred(name, args));
208        }
209    }
210
211    // Single variable or constant
212    Ok(TLExpr::pred(input, vec![]))
213}
214
215fn parse_quantifier(input: &str, is_exists: bool) -> Result<TLExpr> {
216    // Format: "x IN Domain. body" or "x. body"
217    let parts: Vec<&str> = input.splitn(2, '.').collect();
218    if parts.len() != 2 {
219        bail!(
220            "Invalid quantifier syntax: expected '{} VAR [IN DOMAIN]. BODY'",
221            if is_exists { "EXISTS" } else { "FORALL" }
222        );
223    }
224
225    let var_part = parts[0].trim();
226    let body = parse_expression(parts[1].trim())?;
227
228    // Check for "IN Domain" syntax
229    let (var, domain) = if let Some(in_pos) = var_part.find(" IN ") {
230        let var = var_part[..in_pos].trim();
231        let domain = var_part[in_pos + 4..].trim();
232        (var, domain)
233    } else if let Some(in_pos) = var_part.find(" in ") {
234        let var = var_part[..in_pos].trim();
235        let domain = var_part[in_pos + 4..].trim();
236        (var, domain)
237    } else {
238        // Default domain "D"
239        (var_part, "D")
240    };
241
242    if is_exists {
243        Ok(TLExpr::exists(var, domain, body))
244    } else {
245        Ok(TLExpr::forall(var, domain, body))
246    }
247}
248
249fn parse_conditional(input: &str) -> Result<TLExpr> {
250    // Format: "IF cond THEN then_expr ELSE else_expr"
251    let input = input
252        .strip_prefix("IF ")
253        .or_else(|| input.strip_prefix("if "))
254        .unwrap();
255
256    let then_pos = input
257        .find(" THEN ")
258        .or_else(|| input.find(" then "))
259        .ok_or_else(|| anyhow::anyhow!("Missing THEN in IF-THEN-ELSE"))?;
260
261    let else_pos = input
262        .find(" ELSE ")
263        .or_else(|| input.find(" else "))
264        .ok_or_else(|| anyhow::anyhow!("Missing ELSE in IF-THEN-ELSE"))?;
265
266    let cond = parse_expression(input[..then_pos].trim())?;
267    let then_expr = parse_expression(input[then_pos + 6..else_pos].trim())?;
268    let else_expr = parse_expression(input[else_pos + 6..].trim())?;
269
270    Ok(TLExpr::IfThenElse {
271        condition: Box::new(cond),
272        then_branch: Box::new(then_expr),
273        else_branch: Box::new(else_expr),
274    })
275}
276
277fn parse_term(input: &str) -> Result<Term> {
278    let input = input.trim();
279
280    // Check if it's a quoted string (constant)
281    if input.starts_with('"') && input.ends_with('"') {
282        Ok(Term::Const(input[1..input.len() - 1].to_string()))
283    } else {
284        // Variable
285        Ok(Term::var(input))
286    }
287}
288
/// Find the position of the first operator that occurs at the top level
/// (outside any parentheses).
///
/// At each position, the `operators` are tried in the given order. The
/// result is a **byte** offset into `input` and always lies on a character
/// boundary. It can therefore be used directly to slice `input`, even when
/// `input` contains multi-byte characters such as `¬`, `∀` or `→`.
fn find_operator(input: &str, operators: &[&str]) -> Option<usize> {
    let mut depth: i32 = 0;

    // `char_indices` yields byte offsets. The previous char-count index was
    // used to slice `input` and panicked (or mis-located the operator) once
    // any multi-byte character preceded it.
    for (pos, ch) in input.char_indices() {
        match ch {
            '(' => depth += 1,
            ')' => depth -= 1,
            _ => {}
        }

        if depth == 0 {
            let rest = &input[pos..];
            if operators.iter().any(|op| rest.starts_with(*op)) {
                return Some(pos);
            }
        }
    }

    None
}
312
/// Get the byte length of the operator at the start of `input`.
///
/// `input` is expected to begin at an offset returned by `find_operator`. If
/// no known operator matches, the byte length of the first character is
/// returned (1 for ASCII, as before). Slicing past it then never lands inside
/// a multi-byte character. Empty input also yields 1.
fn operator_len(input: &str) -> usize {
    // Every operator token recognised by the parse_* functions. ` ≤ ` was
    // previously missing (` ≥ ` was listed twice), which broke `a ≤ b`.
    const OPERATORS: &[&str] = &[
        "->", "IMPLIES", "=>", "→", " OR ", " | ", "||", " AND ", " & ", "&&", "∧", " = ", " == ",
        " <= ", " ≤ ", " >= ", " ≥ ", " < ", " > ", " != ", " ≠ ", " + ", " - ", " * ", " × ",
        " / ", " ÷ ",
    ];

    OPERATORS
        .iter()
        .find(|op| input.starts_with(**op))
        .map_or_else(
            || input.chars().next().map_or(1, char::len_utf8),
            |op| op.len(),
        )
}
329
/// Check whether the parentheses in `input` are balanced.
///
/// Balanced means every `)` closes an earlier unmatched `(` and no `(` is
/// left open at the end. All other characters are ignored.
fn is_balanced(input: &str) -> bool {
    // `None` marks a `)` with nothing to close; it short-circuits the fold.
    let final_depth = input.chars().try_fold(0i32, |open, ch| match ch {
        '(' => Some(open + 1),
        ')' => (open > 0).then(|| open - 1),
        _ => Some(open),
    });

    final_depth == Some(0)
}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse `src`, panicking (and so failing the test) if it is rejected.
    fn parse_ok(src: &str) -> TLExpr {
        parse_expression(src).unwrap()
    }

    #[test]
    fn test_simple_predicate() {
        assert!(matches!(parse_ok("knows(x, y)"), TLExpr::Pred { .. }));
    }

    #[test]
    fn test_and_operation() {
        assert!(matches!(parse_ok("p(x) AND q(y)"), TLExpr::And(_, _)));
    }

    #[test]
    fn test_arithmetic() {
        assert!(matches!(parse_ok("x + y"), TLExpr::Add(_, _)));
    }

    #[test]
    fn test_comparison() {
        assert!(matches!(parse_ok("x < y"), TLExpr::Lt(_, _)));
    }

    #[test]
    fn test_quantifier() {
        let parsed = parse_ok("EXISTS x IN Person. knows(x, y)");
        assert!(matches!(parsed, TLExpr::Exists { .. }));
    }

    #[test]
    fn test_conditional() {
        let parsed = parse_ok("IF x < 0 THEN 0 ELSE x");
        assert!(matches!(parsed, TLExpr::IfThenElse { .. }));
    }

    #[test]
    fn test_complex_expression() {
        let parsed = parse_ok("(p(x) OR q(y)) AND r(z)");
        assert!(matches!(parsed, TLExpr::And(_, _)));
    }
}