1use crate::{hir, ty::Gcx};
2use alloy_primitives::U256;
3use solar_ast::LitKind;
4use solar_interface::{Span, diagnostics::ErrorGuaranteed};
5use std::fmt;
6
7const RECURSION_LIMIT: usize = 64;
8
9pub struct ConstantEvaluator<'gcx> {
18 pub gcx: Gcx<'gcx>,
19 depth: usize,
20}
21
/// Result of evaluating a constant expression without emitting diagnostics.
// NOTE(review): the `'gcx` parameter is not used by the aliased type.
type EvalResult<'gcx> = Result<IntScalar, EvalError>;
23
24impl<'gcx> ConstantEvaluator<'gcx> {
25 pub fn new(gcx: Gcx<'gcx>) -> Self {
27 Self { gcx, depth: 0 }
28 }
29
30 pub fn eval(&mut self, expr: &hir::Expr<'_>) -> Result<IntScalar, ErrorGuaranteed> {
32 self.try_eval(expr).map_err(|err| self.emit_eval_error(expr, err))
33 }
34
35 pub fn try_eval(&mut self, expr: &hir::Expr<'_>) -> EvalResult<'gcx> {
37 self.depth += 1;
38 if self.depth > RECURSION_LIMIT {
39 return Err(EE::RecursionLimitReached.spanned(expr.span));
40 }
41 let mut res = self.eval_expr(expr);
42 if let Err(e) = &mut res
43 && e.span.is_dummy()
44 {
45 e.span = expr.span;
46 }
47 self.depth = self.depth.checked_sub(1).unwrap();
48 res
49 }
50
51 pub fn emit_eval_error(&self, expr: &hir::Expr<'_>, err: EvalError) -> ErrorGuaranteed {
53 match err.kind {
54 EE::AlreadyEmitted(guar) => guar,
55 _ => {
56 let msg = format!("failed to evaluate constant: {}", err.kind.msg());
57 let label = "evaluation of constant value failed here";
58 self.gcx.dcx().err(msg).span(expr.span).span_label(err.span, label).emit()
59 }
60 }
61 }
62
63 fn eval_expr(&mut self, expr: &hir::Expr<'_>) -> EvalResult<'gcx> {
64 let expr = expr.peel_parens();
65 match expr.kind {
66 hir::ExprKind::Binary(l, bin_op, r) => {
69 let l = self.try_eval(l)?;
70 let r = self.try_eval(r)?;
71 l.binop(&r, bin_op.kind).map_err(Into::into)
72 }
73 hir::ExprKind::Ident(res) => {
77 let Some(v) = res.iter().find_map(|res| res.as_variable()) else {
79 return Err(EE::NonConstantVar.into());
80 };
81
82 let v = self.gcx.hir.variable(v);
83 if v.mutability != Some(hir::VarMut::Constant) {
84 return Err(EE::NonConstantVar.into());
85 }
86 self.try_eval(v.initializer.expect("constant variable has no initializer"))
87 }
88 hir::ExprKind::Lit(lit) => self.eval_lit(lit),
91 hir::ExprKind::Unary(un_op, v) => {
102 let v = self.try_eval(v)?;
103 v.unop(un_op.kind).map_err(Into::into)
104 }
105 hir::ExprKind::Err(guar) => Err(EE::AlreadyEmitted(guar).into()),
106 _ => Err(EE::UnsupportedExpr.into()),
107 }
108 }
109
110 fn eval_lit(&mut self, lit: &hir::Lit<'_>) -> EvalResult<'gcx> {
111 match lit.kind {
112 LitKind::Number(n) => Ok(IntScalar::new(n)),
114 LitKind::Address(address) => Ok(IntScalar::from_be_bytes(address.as_slice())),
116 LitKind::Bool(bool) => Ok(IntScalar::from_be_bytes(&[bool as u8])),
117 LitKind::Err(guar) => Err(EE::AlreadyEmitted(guar).into()),
118 _ => Err(EE::UnsupportedLiteral.into()),
119 }
120 }
121}
122
123pub struct IntScalar {
124 pub data: U256,
125}
126
127impl IntScalar {
128 pub fn new(data: U256) -> Self {
129 Self { data }
130 }
131
132 pub fn from_bool(value: bool) -> Self {
134 Self { data: U256::from(value as u8) }
135 }
136
137 pub fn from_be_bytes(bytes: &[u8]) -> Self {
143 Self { data: U256::from_be_slice(bytes) }
144 }
145
146 pub fn to_bool(&self) -> bool {
148 !self.data.is_zero()
149 }
150
151 pub fn unop(&self, op: hir::UnOpKind) -> Result<Self, EE> {
153 Ok(match op {
154 hir::UnOpKind::PreInc
155 | hir::UnOpKind::PreDec
156 | hir::UnOpKind::PostInc
157 | hir::UnOpKind::PostDec => return Err(EE::UnsupportedUnaryOp),
158 hir::UnOpKind::Not | hir::UnOpKind::BitNot => Self::new(!self.data),
159 hir::UnOpKind::Neg => Self::new(self.data.wrapping_neg()),
160 })
161 }
162
163 pub fn binop(&self, r: &Self, op: hir::BinOpKind) -> Result<Self, EE> {
165 let l = self;
166 Ok(match op {
167 hir::BinOpKind::BitOr => Self::new(l.data | r.data),
176 hir::BinOpKind::BitAnd => Self::new(l.data & r.data),
177 hir::BinOpKind::BitXor => Self::new(l.data ^ r.data),
178 hir::BinOpKind::Shr => {
179 Self::new(l.data.wrapping_shr(r.data.try_into().unwrap_or(usize::MAX)))
180 }
181 hir::BinOpKind::Shl => {
182 Self::new(l.data.wrapping_shl(r.data.try_into().unwrap_or(usize::MAX)))
183 }
184 hir::BinOpKind::Sar => {
185 Self::new(l.data.arithmetic_shr(r.data.try_into().unwrap_or(usize::MAX)))
186 }
187 hir::BinOpKind::Add => {
188 Self::new(l.data.checked_add(r.data).ok_or(EE::ArithmeticOverflow)?)
189 }
190 hir::BinOpKind::Sub => {
191 Self::new(l.data.checked_sub(r.data).ok_or(EE::ArithmeticOverflow)?)
192 }
193 hir::BinOpKind::Pow => {
194 Self::new(l.data.checked_pow(r.data).ok_or(EE::ArithmeticOverflow)?)
195 }
196 hir::BinOpKind::Mul => {
197 Self::new(l.data.checked_mul(r.data).ok_or(EE::ArithmeticOverflow)?)
198 }
199 hir::BinOpKind::Div => Self::new(l.data.checked_div(r.data).ok_or(EE::DivisionByZero)?),
200 hir::BinOpKind::Rem => Self::new(l.data.checked_rem(r.data).ok_or(EE::DivisionByZero)?),
201 hir::BinOpKind::Lt
202 | hir::BinOpKind::Le
203 | hir::BinOpKind::Gt
204 | hir::BinOpKind::Ge
205 | hir::BinOpKind::Eq
206 | hir::BinOpKind::Ne
207 | hir::BinOpKind::Or
208 | hir::BinOpKind::And => return Err(EE::UnsupportedBinaryOp),
209 })
210 }
211}
212
213#[derive(Debug)]
214pub enum EvalErrorKind {
215 RecursionLimitReached,
216 ArithmeticOverflow,
217 DivisionByZero,
218 UnsupportedLiteral,
219 UnsupportedUnaryOp,
220 UnsupportedBinaryOp,
221 UnsupportedExpr,
222 NonConstantVar,
223 AlreadyEmitted(ErrorGuaranteed),
224}
225use EvalErrorKind as EE;
226
227impl EvalErrorKind {
228 pub fn spanned(self, span: Span) -> EvalError {
229 EvalError { kind: self, span }
230 }
231
232 fn msg(&self) -> &'static str {
233 match self {
234 Self::RecursionLimitReached => "recursion limit reached",
235 Self::ArithmeticOverflow => "arithmetic overflow",
236 Self::DivisionByZero => "attempted to divide by zero",
237 Self::UnsupportedLiteral => "unsupported literal",
238 Self::UnsupportedUnaryOp => "unsupported unary operation",
239 Self::UnsupportedBinaryOp => "unsupported binary operation",
240 Self::UnsupportedExpr => "unsupported expression",
241 Self::NonConstantVar => "only constant variables are allowed",
242 Self::AlreadyEmitted(_) => unreachable!(),
243 }
244 }
245}
246
247#[derive(Debug)]
248pub struct EvalError {
249 pub span: Span,
250 pub kind: EvalErrorKind,
251}
252
253impl From<EE> for EvalError {
254 fn from(value: EE) -> Self {
255 Self { kind: value, span: Span::DUMMY }
256 }
257}
258
259impl fmt::Display for EvalError {
260 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
261 self.kind.msg().fmt(f)
262 }
263}
264
/// `EvalError` wraps no underlying error, so the default (`None`) `source` is kept.
impl std::error::Error for EvalError {}