1use std::{fmt::Display, ops::Neg};
2
3use num_bigint::BigInt;
4use num_traits::cast::ToPrimitive;
5use t4_idl_parser::expr::{ConstExpr, Literal, UnaryOpExpr};
6
7#[derive(Debug)]
8pub enum ConstValue {
9 Integer(BigInt),
10 Boolean(bool),
11 Char(char),
12 String(String),
13 FloatingPoint(f64),
14}
15
16impl Display for ConstValue {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 match self {
19 ConstValue::Integer(x) => write!(f, "{x}"),
20 ConstValue::Boolean(x) => write!(f, "{x}"),
21 ConstValue::Char(x) => write!(f, "'{x}'"),
22 ConstValue::String(x) => {
23 let mut i = 1;
24 let sharp = loop {
25 let sharp = "#".repeat(i);
26 if !x.contains(&sharp) {
27 break sharp;
28 }
29 i += 1;
30 };
31 write!(f, "r{sharp}\"{x}\"{sharp}")
32 }
33 ConstValue::FloatingPoint(x) => write!(f, "{x}"),
34 }
35 }
36}
37
38pub fn eval(expr: &ConstExpr) -> ConstValue {
39 match expr {
40 ConstExpr::Literal(n) => eval_literal(n),
41 ConstExpr::Add(left, right) => {
42 arithmetic_op(left, right, &|n1, n2| n1 + n2, &|n1, n2| n1 + n2)
43 }
44 ConstExpr::Sub(left, right) => {
45 arithmetic_op(left, right, &|n1, n2| n1 - n2, &|n1, n2| n1 - n2)
46 }
47 ConstExpr::Div(left, right) => {
48 arithmetic_op(left, right, &|n1, n2| n1 / n2, &|n1, n2| n1 / n2)
49 }
50 ConstExpr::Mul(left, right) => {
51 arithmetic_op(left, right, &|n1, n2| n1 * n2, &|n1, n2| n1 * n2)
52 }
53 ConstExpr::And(left, right) => boolean_op(left, right, &|n1, n2| n1 && n2),
54 ConstExpr::Or(left, right) => boolean_op(left, right, &|n1, n2| n1 || n2),
55 ConstExpr::Xor(left, right) => boolean_op(left, right, &|n1, n2| n1 ^ n2),
56 ConstExpr::LShift(left, right) => int_op(left, right, &|n1, n2| n1 << n2.to_u64().unwrap()),
57 ConstExpr::RShift(left, right) => int_op(left, right, &|n1, n2| n1 >> n2.to_u64().unwrap()),
58 ConstExpr::Mod(left, right) => {
59 arithmetic_op(left, right, &|n1, n2| n1 % n2, &|n1, n2| n1 % n2)
60 }
61 ConstExpr::UnaryOp(e) => eval_unary_op(e),
62 ConstExpr::ScopedName(_n) => todo!(),
63 }
64}
65
66fn eval_unary_op(expr: &UnaryOpExpr) -> ConstValue {
67 match expr {
68 UnaryOpExpr::Minus(e) => {
69 let n = eval(e);
70 match n {
71 ConstValue::Integer(n) => ConstValue::Integer(n.neg()),
72 ConstValue::FloatingPoint(n) => ConstValue::FloatingPoint(-n),
73 _ => panic!("{:?} is not a number", e),
74 }
75 }
76 UnaryOpExpr::Plus(e) => {
77 let n = eval(e);
78 match n {
79 ConstValue::Integer(n) => ConstValue::Integer(n),
80 ConstValue::FloatingPoint(n) => ConstValue::FloatingPoint(n),
81 _ => panic!("{:?} is not a number", e),
82 }
83 }
84 UnaryOpExpr::Negate(e) => {
85 let n = eval(e);
86 match n {
87 ConstValue::Boolean(n) => ConstValue::Boolean(!n),
88 _ => panic!("{:?} is not a boolean", e),
89 }
90 }
91 }
92}
93
94fn eval_literal(expr: &Literal) -> ConstValue {
95 match expr {
96 Literal::Integer(n) => ConstValue::Integer(n.clone()),
97 Literal::Boolean(n) => ConstValue::Boolean(*n),
98 Literal::Char(n) => ConstValue::Char(*n),
99 Literal::String(n) => ConstValue::String(n.clone()),
100 Literal::FloatingPoint(n) => ConstValue::FloatingPoint(*n),
101 Literal::FixedPoint(_) => unimplemented!(),
102 }
103}
104
105fn arithmetic_op(
106 left: &ConstExpr,
107 right: &ConstExpr,
108 int_fn: &dyn Fn(BigInt, BigInt) -> BigInt,
109 float_fn: &dyn Fn(f64, f64) -> f64,
110) -> ConstValue {
111 let n1 = eval(left);
112 assert!(matches!(
113 n1,
114 ConstValue::Integer(_) | ConstValue::FloatingPoint(_)
115 ));
116
117 let n2 = eval(right);
118 assert!(matches!(
119 n2,
120 ConstValue::Integer(_) | ConstValue::FloatingPoint(_)
121 ));
122
123 match (n1, n2) {
124 (ConstValue::Integer(n1), ConstValue::Integer(n2)) => ConstValue::Integer(int_fn(n1, n2)),
125 (ConstValue::FloatingPoint(n1), ConstValue::FloatingPoint(n2)) => {
126 ConstValue::FloatingPoint(float_fn(n1, n2))
127 }
128 _ => panic!("{:?} or/and {:?} is/are not (a) number(s)", left, right),
129 }
130}
131
132fn boolean_op(
133 left: &ConstExpr,
134 right: &ConstExpr,
135 func: &dyn Fn(bool, bool) -> bool,
136) -> ConstValue {
137 let n1 = eval(left);
138 assert!(matches!(n1, ConstValue::Boolean(_)));
139
140 let n2 = eval(right);
141 assert!(matches!(n2, ConstValue::Boolean(_)));
142
143 match (n1, n2) {
144 (ConstValue::Boolean(n1), ConstValue::Boolean(n2)) => ConstValue::Boolean(func(n1, n2)),
145 _ => unreachable!(),
146 }
147}
148
149fn int_op(
150 left: &ConstExpr,
151 right: &ConstExpr,
152 func: &dyn Fn(BigInt, BigInt) -> BigInt,
153) -> ConstValue {
154 let n1 = eval(left);
155 assert!(matches!(n1, ConstValue::Integer(_)));
156
157 let n2 = eval(right);
158 assert!(matches!(n2, ConstValue::Integer(_)));
159
160 match (n1, n2) {
161 (ConstValue::Integer(n1), ConstValue::Integer(n2)) => ConstValue::Integer(func(n1, n2)),
162 _ => unreachable!(),
163 }
164}