rush_sh/
arithmetic.rs

1use super::state::ShellState;
2
3/// Token types for arithmetic expressions
4#[derive(Debug, Clone, PartialEq)]
5pub enum ArithmeticToken {
6    Number(i64),
7    Variable(String),
8    Operator(ArithmeticOperator),
9    LeftParen,
10    RightParen,
11}
12
/// Operators understood by the arithmetic evaluator.
///
/// How tightly each operator binds is reported by
/// [`ArithmeticOperator::precedence`]; [`ArithmeticOperator::is_unary`]
/// separates the prefix operators from the binary ones.
#[derive(Debug, Clone, PartialEq)]
pub enum ArithmeticOperator {
    // ---- Prefix operators (precedence 100) ----
    /// `!` — logical negation.
    LogicalNot,
    /// `~` — bitwise complement.
    BitwiseNot,

    // ---- Binary operators, from tightest to loosest binding ----
    /// `*` (precedence 90)
    Multiply,
    /// `/` (precedence 90)
    Divide,
    /// `%` (precedence 90)
    Modulo,
    /// `+` (precedence 80)
    Add,
    /// `-` (precedence 80)
    Subtract,
    /// `<<` (precedence 70)
    ShiftLeft,
    /// `>>` (precedence 70)
    ShiftRight,
    /// `<` (precedence 60)
    LessThan,
    /// `<=` (precedence 60)
    LessEqual,
    /// `>` (precedence 60)
    GreaterThan,
    /// `>=` (precedence 60)
    GreaterEqual,
    /// `==` (precedence 50)
    Equal,
    /// `!=` (precedence 50)
    NotEqual,
    /// `&` (precedence 40)
    BitwiseAnd,
    /// `^` (precedence 30)
    BitwiseXor,
    /// `|` (precedence 20)
    BitwiseOr,
    /// `&&` (precedence 10)
    LogicalAnd,
    /// `||` (precedence 5)
    LogicalOr,
}

impl ArithmeticOperator {
    /// Binding strength of the operator; a larger value binds more tightly.
    ///
    /// Prefix operators bind tightest (100) and `||` loosest (5).
    pub fn precedence(&self) -> i32 {
        match self {
            Self::LogicalNot | Self::BitwiseNot => 100,
            Self::Multiply | Self::Divide | Self::Modulo => 90,
            Self::Add | Self::Subtract => 80,
            Self::ShiftLeft | Self::ShiftRight => 70,
            Self::LessThan | Self::LessEqual | Self::GreaterThan | Self::GreaterEqual => 60,
            Self::Equal | Self::NotEqual => 50,
            Self::BitwiseAnd => 40,
            Self::BitwiseXor => 30,
            Self::BitwiseOr => 20,
            Self::LogicalAnd => 10,
            Self::LogicalOr => 5,
        }
    }

    /// Returns `true` for the prefix operators `!` and `~`, which take a
    /// single operand; every other operator is binary.
    pub fn is_unary(&self) -> bool {
        matches!(self, Self::LogicalNot | Self::BitwiseNot)
    }
}
71
/// Errors that can occur while tokenizing, parsing or evaluating an
/// arithmetic expression.
#[derive(Debug, Clone)]
pub enum ArithmeticError {
    /// Malformed input; the message describes what was wrong.
    SyntaxError(String),
    /// The right-hand operand of `/` or `%` was zero.
    DivisionByZero,
    /// A `(` without a matching `)`, or a `)` without a matching `(`.
    UnmatchedParentheses,
    /// The expression was empty or contained only whitespace.
    EmptyExpression,
}

impl std::fmt::Display for ArithmeticError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ArithmeticError::SyntaxError(msg) => write!(f, "Syntax error: {}", msg),
            ArithmeticError::DivisionByZero => f.write_str("Division by zero"),
            ArithmeticError::UnmatchedParentheses => f.write_str("Unmatched parentheses"),
            ArithmeticError::EmptyExpression => f.write_str("Empty expression"),
        }
    }
}

/// Makes `ArithmeticError` a standard error type, so it can be boxed as
/// `Box<dyn std::error::Error>` and propagated with `?` from such functions.
impl std::error::Error for ArithmeticError {}
91
92/// Tokenize an arithmetic expression into tokens
93pub fn tokenize_expression(expr: &str) -> Result<Vec<ArithmeticToken>, ArithmeticError> {
94    let mut tokens = Vec::new();
95    let mut chars = expr.chars().peekable();
96
97    while let Some(ch) = chars.next() {
98        match ch {
99            ' ' | '\t' | '\n' => continue, // Skip whitespace
100
101            '(' => tokens.push(ArithmeticToken::LeftParen),
102            ')' => tokens.push(ArithmeticToken::RightParen),
103
104            '+' => {
105                if let Some(next_ch) = chars.peek()
106                    && *next_ch == '+'
107                {
108                    return Err(ArithmeticError::SyntaxError("Unexpected ++".to_string()));
109                }
110                tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Add));
111            }
112
113            '-' => {
114                if let Some(next_ch) = chars.peek()
115                    && *next_ch == '-'
116                {
117                    return Err(ArithmeticError::SyntaxError("Unexpected --".to_string()));
118                }
119                tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Subtract));
120            }
121
122            '*' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Multiply)),
123            '/' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Divide)),
124            '%' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Modulo)),
125
126            '<' => {
127                if let Some(&next_ch) = chars.peek() {
128                    if next_ch == '<' {
129                        chars.next();
130                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::ShiftLeft));
131                    } else if next_ch == '=' {
132                        chars.next();
133                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LessEqual));
134                    } else {
135                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LessThan));
136                    }
137                } else {
138                    tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LessThan));
139                }
140            }
141
142            '>' => {
143                if let Some(&next_ch) = chars.peek() {
144                    if next_ch == '>' {
145                        chars.next();
146                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::ShiftRight));
147                    } else if next_ch == '=' {
148                        chars.next();
149                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::GreaterEqual));
150                    } else {
151                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::GreaterThan));
152                    }
153                } else {
154                    tokens.push(ArithmeticToken::Operator(ArithmeticOperator::GreaterThan));
155                }
156            }
157
158            '=' => {
159                if let Some(&next_ch) = chars.peek() {
160                    if next_ch == '=' {
161                        chars.next();
162                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::Equal));
163                    } else {
164                        return Err(ArithmeticError::SyntaxError("Unexpected =".to_string()));
165                    }
166                } else {
167                    return Err(ArithmeticError::SyntaxError("Unexpected =".to_string()));
168                }
169            }
170
171            '!' => {
172                if let Some(&next_ch) = chars.peek() {
173                    if next_ch == '=' {
174                        chars.next();
175                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::NotEqual));
176                    } else {
177                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalNot));
178                    }
179                } else {
180                    tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalNot));
181                }
182            }
183
184            '&' => {
185                if let Some(&next_ch) = chars.peek() {
186                    if next_ch == '&' {
187                        chars.next();
188                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalAnd));
189                    } else {
190                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseAnd));
191                    }
192                } else {
193                    tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseAnd));
194                }
195            }
196
197            '|' => {
198                if let Some(&next_ch) = chars.peek() {
199                    if next_ch == '|' {
200                        chars.next();
201                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::LogicalOr));
202                    } else {
203                        tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseOr));
204                    }
205                } else {
206                    tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseOr));
207                }
208            }
209
210            '^' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseXor)),
211            '~' => tokens.push(ArithmeticToken::Operator(ArithmeticOperator::BitwiseNot)),
212
213            // Numbers and variables
214            '0'..='9' => {
215                let mut num_str = String::new();
216                num_str.push(ch);
217                while let Some(&next_ch) = chars.peek() {
218                    if next_ch.is_ascii_digit() {
219                        num_str.push(next_ch);
220                        chars.next();
221                    } else {
222                        break;
223                    }
224                }
225                match num_str.parse::<i64>() {
226                    Ok(num) => tokens.push(ArithmeticToken::Number(num)),
227                    Err(_) => {
228                        return Err(ArithmeticError::SyntaxError("Invalid number".to_string()));
229                    }
230                }
231            }
232
233            // Variables (start with letter or underscore)
234            'a'..='z' | 'A'..='Z' | '_' => {
235                let mut var_name = String::new();
236                var_name.push(ch);
237                while let Some(&next_ch) = chars.peek() {
238                    if next_ch.is_alphanumeric() || next_ch == '_' {
239                        var_name.push(next_ch);
240                        chars.next();
241                    } else {
242                        break;
243                    }
244                }
245                tokens.push(ArithmeticToken::Variable(var_name));
246            }
247
248            _ => {
249                return Err(ArithmeticError::SyntaxError(format!(
250                    "Unexpected character: {}",
251                    ch
252                )));
253            }
254        }
255    }
256
257    Ok(tokens)
258}
259
260/// Parse tokens into Reverse Polish Notation (RPN) using Shunting-yard algorithm
261pub fn parse_to_rpn(tokens: Vec<ArithmeticToken>) -> Result<Vec<ArithmeticToken>, ArithmeticError> {
262    let mut output = Vec::new();
263    let mut operators = Vec::new();
264
265    for token in tokens {
266        match token {
267            ArithmeticToken::Number(_) | ArithmeticToken::Variable(_) => {
268                output.push(token);
269            }
270
271            ArithmeticToken::Operator(op) => {
272                // Handle unary operators
273                if op.is_unary()
274                    && (output.is_empty()
275                        || matches!(
276                            output.last(),
277                            Some(ArithmeticToken::Operator(_) | ArithmeticToken::LeftParen)
278                        ))
279                {
280                    // This is a unary operator
281                    while !operators.is_empty() {
282                        if let Some(ArithmeticToken::Operator(top_op)) = operators.last() {
283                            if top_op.precedence() >= op.precedence() && !top_op.is_unary() {
284                                output.push(operators.pop().unwrap());
285                            } else {
286                                break;
287                            }
288                        } else {
289                            break;
290                        }
291                    }
292                    operators.push(ArithmeticToken::Operator(op));
293                } else {
294                    // Binary operator
295                    while !operators.is_empty() {
296                        if let Some(ArithmeticToken::Operator(top_op)) = operators.last() {
297                            if (top_op.precedence() > op.precedence())
298                                || (top_op.precedence() == op.precedence() && !op.is_unary())
299                            {
300                                output.push(operators.pop().unwrap());
301                            } else {
302                                break;
303                            }
304                        } else {
305                            break;
306                        }
307                    }
308                    operators.push(ArithmeticToken::Operator(op));
309                }
310            }
311
312            ArithmeticToken::LeftParen => {
313                operators.push(token);
314            }
315
316            ArithmeticToken::RightParen => {
317                let mut found_left = false;
318                while let Some(op) = operators.pop() {
319                    if op == ArithmeticToken::LeftParen {
320                        found_left = true;
321                        break;
322                    } else {
323                        output.push(op);
324                    }
325                }
326                if !found_left {
327                    return Err(ArithmeticError::UnmatchedParentheses);
328                }
329            }
330        }
331    }
332
333    // Pop remaining operators
334    while let Some(op) = operators.pop() {
335        if op == ArithmeticToken::LeftParen {
336            return Err(ArithmeticError::UnmatchedParentheses);
337        }
338        output.push(op);
339    }
340
341    Ok(output)
342}
343
344/// Evaluate an arithmetic expression in Reverse Polish Notation
345pub fn evaluate_rpn(
346    rpn_tokens: Vec<ArithmeticToken>,
347    shell_state: &ShellState,
348) -> Result<i64, ArithmeticError> {
349    let mut stack = Vec::new();
350
351    for token in rpn_tokens {
352        match token {
353            ArithmeticToken::Number(num) => {
354                stack.push(num);
355            }
356
357            ArithmeticToken::Variable(var_name) => {
358                if let Some(value) = shell_state.get_var(&var_name) {
359                    match value.parse::<i64>() {
360                        Ok(num) => stack.push(num),
361                        Err(_) => {
362                            // Variable exists but is not a valid number, treat as 0 (bash behavior)
363                            stack.push(0)
364                        }
365                    }
366                } else {
367                    // Variable is undefined, treat as 0 (bash behavior)
368                    stack.push(0)
369                }
370            }
371
372            ArithmeticToken::Operator(op) => {
373                if op.is_unary() {
374                    if stack.is_empty() {
375                        return Err(ArithmeticError::SyntaxError(
376                            "Missing operand for unary operator".to_string(),
377                        ));
378                    }
379                    let operand = stack.pop().unwrap();
380                    let result = match op {
381                        ArithmeticOperator::LogicalNot => !operand,
382                        ArithmeticOperator::BitwiseNot => !operand,
383                        _ => unreachable!(),
384                    };
385                    stack.push(result);
386                } else {
387                    if stack.len() < 2 {
388                        return Err(ArithmeticError::SyntaxError(
389                            "Missing operands for binary operator".to_string(),
390                        ));
391                    }
392                    let right = stack.pop().unwrap();
393                    let left = stack.pop().unwrap();
394                    let result = match op {
395                        ArithmeticOperator::Add => left + right,
396                        ArithmeticOperator::Subtract => left - right,
397                        ArithmeticOperator::Multiply => left * right,
398                        ArithmeticOperator::Divide => {
399                            if right == 0 {
400                                return Err(ArithmeticError::DivisionByZero);
401                            }
402                            left / right
403                        }
404                        ArithmeticOperator::Modulo => {
405                            if right == 0 {
406                                return Err(ArithmeticError::DivisionByZero);
407                            }
408                            left % right
409                        }
410                        ArithmeticOperator::ShiftLeft => left << right,
411                        ArithmeticOperator::ShiftRight => left >> right,
412                        ArithmeticOperator::LessThan => {
413                            if left < right {
414                                1
415                            } else {
416                                0
417                            }
418                        }
419                        ArithmeticOperator::LessEqual => {
420                            if left <= right {
421                                1
422                            } else {
423                                0
424                            }
425                        }
426                        ArithmeticOperator::GreaterThan => {
427                            if left > right {
428                                1
429                            } else {
430                                0
431                            }
432                        }
433                        ArithmeticOperator::GreaterEqual => {
434                            if left >= right {
435                                1
436                            } else {
437                                0
438                            }
439                        }
440                        ArithmeticOperator::Equal => {
441                            if left == right {
442                                1
443                            } else {
444                                0
445                            }
446                        }
447                        ArithmeticOperator::NotEqual => {
448                            if left != right {
449                                1
450                            } else {
451                                0
452                            }
453                        }
454                        ArithmeticOperator::BitwiseAnd => left & right,
455                        ArithmeticOperator::BitwiseXor => left ^ right,
456                        ArithmeticOperator::BitwiseOr => left | right,
457                        ArithmeticOperator::LogicalAnd => {
458                            if left != 0 && right != 0 {
459                                1
460                            } else {
461                                0
462                            }
463                        }
464                        ArithmeticOperator::LogicalOr => {
465                            if left != 0 || right != 0 {
466                                1
467                            } else {
468                                0
469                            }
470                        }
471                        _ => unreachable!(),
472                    };
473                    stack.push(result);
474                }
475            }
476
477            ArithmeticToken::LeftParen | ArithmeticToken::RightParen => {
478                return Err(ArithmeticError::SyntaxError(
479                    "Unexpected parenthesis in RPN".to_string(),
480                ));
481            }
482        }
483    }
484
485    if stack.len() != 1 {
486        return Err(ArithmeticError::SyntaxError(
487            "Invalid expression".to_string(),
488        ));
489    }
490
491    Ok(stack[0])
492}
493
494/// Main function to evaluate an arithmetic expression
495pub fn evaluate_arithmetic_expression(
496    expr: &str,
497    shell_state: &ShellState,
498) -> Result<i64, ArithmeticError> {
499    if expr.trim().is_empty() {
500        return Err(ArithmeticError::EmptyExpression);
501    }
502
503    let tokens = tokenize_expression(expr)?;
504    let rpn_tokens = parse_to_rpn(tokens)?;
505    let result = evaluate_rpn(rpn_tokens, shell_state)?;
506
507    Ok(result)
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenize_simple_numbers() {
        let tokens = tokenize_expression("42").unwrap();
        assert_eq!(tokens, vec![ArithmeticToken::Number(42)]);
    }

    #[test]
    fn test_tokenize_operators() {
        let tokens = tokenize_expression("2+3").unwrap();
        assert_eq!(
            tokens,
            vec![
                ArithmeticToken::Number(2),
                ArithmeticToken::Operator(ArithmeticOperator::Add),
                ArithmeticToken::Number(3)
            ]
        );
    }

    #[test]
    fn test_tokenize_parentheses() {
        let tokens = tokenize_expression("(2+3)").unwrap();
        assert_eq!(
            tokens,
            vec![
                ArithmeticToken::LeftParen,
                ArithmeticToken::Number(2),
                ArithmeticToken::Operator(ArithmeticOperator::Add),
                ArithmeticToken::Number(3),
                ArithmeticToken::RightParen
            ]
        );
    }

    #[test]
    fn test_tokenize_variables() {
        let tokens = tokenize_expression("x+y").unwrap();
        assert_eq!(
            tokens,
            vec![
                ArithmeticToken::Variable("x".to_string()),
                ArithmeticToken::Operator(ArithmeticOperator::Add),
                ArithmeticToken::Variable("y".to_string())
            ]
        );
    }

    #[test]
    fn test_tokenize_rejects_unsupported_syntax() {
        // Assignment, increment/decrement and stray characters are not supported.
        for expr in ["x = 1", "x++", "x--", "2 $ 3"] {
            assert!(
                matches!(tokenize_expression(expr), Err(ArithmeticError::SyntaxError(_))),
                "expected syntax error for {:?}",
                expr
            );
        }
    }

    #[test]
    fn test_evaluate_simple() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("42", &shell_state).unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_evaluate_addition() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("2+3", &shell_state).unwrap();
        assert_eq!(result, 5);
    }

    #[test]
    fn test_evaluate_with_precedence() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("2+3*4", &shell_state).unwrap();
        assert_eq!(result, 14); // 3*4 = 12, +2 = 14
    }

    #[test]
    fn test_evaluate_with_parentheses() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("(2+3)*4", &shell_state).unwrap();
        assert_eq!(result, 20); // (2+3) = 5, *4 = 20
    }

    #[test]
    fn test_evaluate_comparison() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("5>3", &shell_state).unwrap();
        assert_eq!(result, 1); // true

        let result = evaluate_arithmetic_expression("3>5", &shell_state).unwrap();
        assert_eq!(result, 0); // false
    }

    #[test]
    fn test_evaluate_operator_families() {
        let shell_state = ShellState::new();
        let cases: &[(&str, i64)] = &[
            ("7 % 3", 1),
            ("1 << 4", 16),
            ("16 >> 2", 4),
            ("6 & 3", 2),
            ("6 | 3", 7),
            ("6 ^ 3", 5),
            ("~0", -1),
            ("1 && 0", 0),
            ("0 || 2", 1),
            ("4 == 4", 1),
            ("4 != 4", 0),
            ("3 <= 3", 1),
            ("2 >= 3", 0),
        ];
        for &(expr, expected) in cases {
            assert_eq!(
                evaluate_arithmetic_expression(expr, &shell_state).unwrap(),
                expected,
                "expression: {}",
                expr
            );
        }
    }

    #[test]
    fn test_evaluate_left_associativity_and_nesting() {
        let shell_state = ShellState::new();
        assert_eq!(
            evaluate_arithmetic_expression("10 - 4 - 3", &shell_state).unwrap(),
            3
        );
        assert_eq!(
            evaluate_arithmetic_expression("100 / 10 / 5", &shell_state).unwrap(),
            2
        );
        assert_eq!(
            evaluate_arithmetic_expression("((1+2)*(3+4))", &shell_state).unwrap(),
            21
        );
    }

    #[test]
    fn test_evaluate_variable() {
        let mut shell_state = ShellState::new();
        shell_state.set_var("x", "10".to_string());
        let result = evaluate_arithmetic_expression("x + 5", &shell_state).unwrap();
        assert_eq!(result, 15);
    }

    #[test]
    fn test_evaluate_division_by_zero() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("5/0", &shell_state);
        assert!(matches!(result, Err(ArithmeticError::DivisionByZero)));
    }

    #[test]
    fn test_evaluate_modulo_by_zero() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("5 % 0", &shell_state);
        assert!(matches!(result, Err(ArithmeticError::DivisionByZero)));
    }

    #[test]
    fn test_evaluate_unmatched_parentheses() {
        let shell_state = ShellState::new();
        for expr in ["(2+3", "2+3)"] {
            assert!(
                matches!(
                    evaluate_arithmetic_expression(expr, &shell_state),
                    Err(ArithmeticError::UnmatchedParentheses)
                ),
                "expected unmatched parentheses for {:?}",
                expr
            );
        }
    }

    #[test]
    fn test_evaluate_empty_expression() {
        let shell_state = ShellState::new();
        assert!(matches!(
            evaluate_arithmetic_expression("  \t ", &shell_state),
            Err(ArithmeticError::EmptyExpression)
        ));
    }

    #[test]
    fn test_evaluate_undefined_variable() {
        let shell_state = ShellState::new();
        let result = evaluate_arithmetic_expression("undefined + 5", &shell_state);
        // Undefined variables are treated as 0 (bash behavior)
        assert_eq!(result.unwrap(), 5);
    }
}
621}