Skip to main content

sentio_core/rules/anchor/
unchecked_arithmetic.rs

1use crate::finding::SourceLocation;
2use crate::rules::{Rule, RuleContext, RuleMatch, RuleMetadata, RuleSeverity};
3use crate::syntax::ParsedFile;
4use quote::ToTokens;
5use syn::spanned::Spanned;
6use syn::visit::{self, Visit};
7use syn::{BinOp, ExprBinary};
8
9#[derive(Debug, Default)]
10pub struct UncheckedArithmeticRule;
11
12impl Rule for UncheckedArithmeticRule {
13    fn metadata(&self) -> &RuleMetadata {
14        static METADATA: RuleMetadata = RuleMetadata {
15            id: "SW005",
16            title: "Unchecked arithmetic",
17            severity: RuleSeverity::High,
18            description: "Detects arithmetic operations (+, -, *) on account data that can \
19                silently overflow or underflow in release builds, where Rust wraps by default.",
20            fix_guidance: "Use checked_add(), checked_sub(), or checked_mul() and propagate \
21                the error with ?, or use saturating_add()/saturating_sub() when wrapping is intentional.",
22        };
23        &METADATA
24    }
25
26    fn match_file(&self, file: &ParsedFile, _ctx: &RuleContext<'_>) -> Vec<RuleMatch> {
27        let mut collector = ArithmeticCollector {
28            findings: Vec::new(),
29        };
30        visit::visit_file(&mut collector, &file.syntax);
31
32        collector
33            .findings
34            .into_iter()
35            .map(|(message, line, column)| RuleMatch {
36                rule_id: "SW005",
37                severity: RuleSeverity::High,
38                message,
39                location: SourceLocation {
40                    path: file.path.display().to_string(),
41                    line,
42                    column,
43                },
44                help: Some(
45                    "Replace `x += y` with `x = x.checked_add(y).ok_or(ErrorCode::Overflow)?`, \
46                    or use `saturating_add` if overflow should saturate rather than error."
47                        .to_string(),
48                ),
49            })
50            .collect()
51    }
52}
53
54struct ArithmeticCollector {
55    findings: Vec<(String, usize, usize)>,
56}
57
58impl<'ast> Visit<'ast> for ArithmeticCollector {
59    fn visit_expr_binary(&mut self, node: &'ast ExprBinary) {
60        let left = node.left.to_token_stream().to_string();
61        let right = node.right.to_token_stream().to_string();
62        let loc = node.left.span().start();
63
64        match &node.op {
65            // Compound assignments: +=, -=, *=
66            // Only flag when the target has a field access — loop counters like `i += 1` are skipped.
67            BinOp::AddAssign(_) | BinOp::SubAssign(_) | BinOp::MulAssign(_)
68                if has_field_access(&left) =>
69            {
70                let op = op_symbol(&node.op);
71                self.findings.push((
72                    format!(
73                        "unchecked `{op}` on `{}`; can overflow or underflow in release builds",
74                        left.trim()
75                    ),
76                    loc.line,
77                    loc.column + 1,
78                ));
79            }
80            // Pure arithmetic: +, -, *
81            // Flag when at least one operand is a field access (account data involved).
82            BinOp::Add(_) | BinOp::Sub(_) | BinOp::Mul(_)
83                if has_field_access(&left) || has_field_access(&right) =>
84            {
85                let op = op_symbol(&node.op);
86                self.findings.push((
87                    format!(
88                        "unchecked `{op}` involving account field; can overflow or underflow in release builds"
89                    ),
90                    loc.line,
91                    loc.column + 1,
92                ));
93            }
94            _ => {}
95        }
96
97        visit::visit_expr_binary(self, node);
98    }
99}
100
/// Heuristically decides whether a stringified operand contains a `.` access.
///
/// `expr` is the operand rendered as text (callers in this file pass the
/// output of `to_token_stream().to_string()`). Surrounding whitespace is
/// ignored. Returns `true` when the text contains a `.` and is not simply a
/// floating-point literal such as `1.0`, `3.14_f64` or `2.5e3`.
fn has_field_access(expr: &str) -> bool {
    let trimmed = expr.trim();
    !is_float_literal(trimmed) && trimmed.contains('.')
}

/// Returns `true` when `token` is spelled like a Rust floating-point literal:
/// it starts with a digit, may contain `_` separators and an exponent, and may
/// end in an `f32`/`f64` suffix.
fn is_float_literal(token: &str) -> bool {
    // Literals always begin with a digit. This guard also rejects spellings
    // that `f64::from_str` accepts but that are not literals (`inf`, `nan`,
    // `.5`, a leading sign).
    if !token.starts_with(|c: char| c.is_ascii_digit()) {
        return false;
    }
    let unsuffixed = token
        .strip_suffix("f32")
        .or_else(|| token.strip_suffix("f64"))
        .unwrap_or(token);
    // `f64::from_str` does not understand `_` digit separators, so drop them first.
    let digits: String = unsuffixed.chars().filter(|&c| c != '_').collect();
    digits.parse::<f64>().is_ok()
}
122
123fn op_symbol(op: &BinOp) -> &'static str {
124    match op {
125        BinOp::Add(_) | BinOp::AddAssign(_) => "+",
126        BinOp::Sub(_) | BinOp::SubAssign(_) => "-",
127        BinOp::Mul(_) | BinOp::MulAssign(_) => "*",
128        _ => unreachable!(),
129    }
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rules::RuleContext;
    use std::path::PathBuf;

    /// Parses `source` into a `ParsedFile` located at `src/lib.rs`.
    fn parse_file(source: &str) -> ParsedFile {
        let syntax = syn::parse_file(source).expect("source should parse");
        ParsedFile {
            path: PathBuf::from("src/lib.rs"),
            source: source.to_string(),
            syntax,
        }
    }

    /// Runs `UncheckedArithmeticRule` over `source` as a single-file project
    /// and returns whatever it reports.
    fn findings_for(source: &str) -> Vec<RuleMatch> {
        let file = parse_file(source);
        let ctx = RuleContext {
            files: std::slice::from_ref(&file),
        };
        UncheckedArithmeticRule.match_file(&file, &ctx)
    }

    #[test]
    fn flags_compound_add_assign_on_account_field() {
        let findings = findings_for(
            r#"
            use anchor_lang::prelude::*;
            pub fn handler(ctx: Context<Deposit>, amount: u64) -> Result<()> {
                ctx.accounts.vault.balance += amount;
                Ok(())
            }
        "#,
        );
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].rule_id, "SW005");
        assert!(findings[0].message.contains("+"));
    }

    #[test]
    fn flags_sub_assign_and_mul_on_account_field() {
        let findings = findings_for(
            r#"
            use anchor_lang::prelude::*;
            pub fn handler(ctx: Context<Transfer>, amount: u64, rate: u64) -> Result<()> {
                ctx.accounts.vault.balance -= amount;
                let fee = ctx.accounts.vault.balance * rate;
                Ok(())
            }
        "#,
        );
        assert_eq!(findings.len(), 2);
        assert!(findings.iter().all(|f| f.rule_id == "SW005"));
    }

    #[test]
    fn does_not_flag_loop_counter_or_local_arithmetic() {
        let findings = findings_for(
            r#"
            use anchor_lang::prelude::*;
            pub fn handler(_ctx: Context<Example>, amount: u64, fee: u64) -> Result<()> {
                let mut i = 0u64;
                i += 1;
                let total = amount + fee;
                Ok(())
            }
        "#,
        );
        assert!(findings.is_empty());
    }

    #[test]
    fn does_not_flag_checked_arithmetic() {
        let findings = findings_for(
            r#"
            use anchor_lang::prelude::*;
            pub fn handler(ctx: Context<Deposit>, amount: u64) -> Result<()> {
                ctx.accounts.vault.balance = ctx.accounts.vault.balance
                    .checked_add(amount)
                    .ok_or(ErrorCode::Overflow)?;
                Ok(())
            }
            #[error_code]
            pub enum ErrorCode { Overflow }
        "#,
        );
        assert!(findings.is_empty());
    }

    #[test]
    fn does_not_flag_saturating_arithmetic() {
        let findings = findings_for(
            r#"
            use anchor_lang::prelude::*;
            pub fn handler(ctx: Context<Deposit>, amount: u64) -> Result<()> {
                ctx.accounts.vault.balance = ctx.accounts.vault.balance.saturating_add(amount);
                Ok(())
            }
        "#,
        );
        assert!(findings.is_empty());
    }
}