Skip to main content

sentio_core/rules/anchor/unchecked_arithmetic.rs

1use crate::finding::SourceLocation;
2use crate::rules::{Rule, RuleContext, RuleMatch, RuleMetadata, RuleSeverity};
3use crate::syntax::ParsedFile;
4use quote::ToTokens;
5use syn::spanned::Spanned;
6use syn::visit::{self, Visit};
7use syn::{BinOp, ExprBinary};
8
/// Rule SW005: flags `+`, `-`, `*` and their compound-assignment forms when an
/// operand looks like a field access (used as a proxy for account data).
#[derive(Default, Debug)]
pub struct UncheckedArithmeticRule;
11
12impl Rule for UncheckedArithmeticRule {
13    fn metadata(&self) -> &RuleMetadata {
14        static METADATA: RuleMetadata = RuleMetadata {
15            id: "SW005",
16            title: "Unchecked arithmetic",
17            severity: RuleSeverity::High,
18            description: "Detects arithmetic operations (+, -, *) on account data that can \
19                silently overflow or underflow in release builds, where Rust wraps by default.",
20            fix_guidance: "Use checked_add(), checked_sub(), or checked_mul() and propagate \
21                the error with ?, or use saturating_add()/saturating_sub() when wrapping is intentional.",
22        };
23        &METADATA
24    }
25
26    fn match_file(&self, file: &ParsedFile, _ctx: &RuleContext<'_>) -> Vec<RuleMatch> {
27        let mut collector = ArithmeticCollector { findings: Vec::new() };
28        visit::visit_file(&mut collector, &file.syntax);
29
30        collector
31            .findings
32            .into_iter()
33            .map(|(message, line, column)| RuleMatch {
34                rule_id: "SW005",
35                severity: RuleSeverity::High,
36                message,
37                location: SourceLocation {
38                    path: file.path.display().to_string(),
39                    line,
40                    column,
41                },
42                help: Some(
43                    "Replace `x += y` with `x = x.checked_add(y).ok_or(ErrorCode::Overflow)?`, \
44                    or use `saturating_add` if overflow should saturate rather than error."
45                        .to_string(),
46                ),
47            })
48            .collect()
49    }
50}
51
/// One flagged expression: `(message, line, column)`.
type ArithmeticFinding = (String, usize, usize);

/// Visitor state for SW005: findings accumulated in the order the syntax tree
/// is walked. `line` and `column` are taken from the span start of the flagged
/// expression's left operand, with the column shifted by one by the visitor.
struct ArithmeticCollector {
    findings: Vec<ArithmeticFinding>,
}
55
56impl<'ast> Visit<'ast> for ArithmeticCollector {
57    fn visit_expr_binary(&mut self, node: &'ast ExprBinary) {
58        let left = node.left.to_token_stream().to_string();
59        let right = node.right.to_token_stream().to_string();
60        let loc = node.left.span().start();
61
62        match &node.op {
63            // Compound assignments: +=, -=, *=
64            // Only flag when the target has a field access — loop counters like `i += 1` are skipped.
65            BinOp::AddAssign(_) | BinOp::SubAssign(_) | BinOp::MulAssign(_) => {
66                if has_field_access(&left) {
67                    let op = op_symbol(&node.op);
68                    self.findings.push((
69                        format!(
70                            "unchecked `{op}` on `{}`; can overflow or underflow in release builds",
71                            left.trim()
72                        ),
73                        loc.line,
74                        loc.column + 1,
75                    ));
76                }
77            }
78            // Pure arithmetic: +, -, *
79            // Flag when at least one operand is a field access (account data involved).
80            BinOp::Add(_) | BinOp::Sub(_) | BinOp::Mul(_) => {
81                if has_field_access(&left) || has_field_access(&right) {
82                    let op = op_symbol(&node.op);
83                    self.findings.push((
84                        format!(
85                            "unchecked `{op}` involving account field; can overflow or underflow in release builds"
86                        ),
87                        loc.line,
88                        loc.column + 1,
89                    ));
90                }
91            }
92            _ => {}
93        }
94
95        visit::visit_expr_binary(self, node);
96    }
97}
98
/// Heuristically reports whether a stringified operand involves a field access.
///
/// `expr` is expected to be the `to_token_stream().to_string()` rendering of a
/// `syn` expression. That rendering separates tokens with spaces, so a field
/// access such as `vault.balance` appears as `vault . balance`, while a numeric
/// literal such as `3.14_f64` stays a single space-free token. Any `.` is
/// therefore treated as a field (or method) access, unless the whole operand is
/// a numeric literal, in which case the `.` is a decimal point.
fn has_field_access(expr: &str) -> bool {
    let trimmed = expr.trim();
    !is_numeric_literal(trimmed) && trimmed.contains('.')
}

/// Returns `true` when `text` is a single numeric literal token, optionally
/// negated, as rendered by `quote`: e.g. `42`, `1.0`, `3.14_f64`, `2.0f32`,
/// `2.5e3`, `1.5e-3` or `- 1.0`.
///
/// Literal suffixes and exponents may contain any ASCII letter (including `e`
/// and `f`), so letters are accepted anywhere after the leading digit.
fn is_numeric_literal(text: &str) -> bool {
    // A negated literal renders as `- 1.0`: drop the sign and the space after it.
    let text = text.strip_prefix('-').map_or(text, str::trim_start);
    text.starts_with(|c: char| c.is_ascii_digit())
        && text
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || matches!(c, '.' | '_' | '+' | '-'))
}
113
114fn op_symbol(op: &BinOp) -> &'static str {
115    match op {
116        BinOp::Add(_) | BinOp::AddAssign(_) => "+",
117        BinOp::Sub(_) | BinOp::SubAssign(_) => "-",
118        BinOp::Mul(_) | BinOp::MulAssign(_) => "*",
119        _ => unreachable!(),
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rules::RuleContext;
    use std::path::PathBuf;

    /// Parses `source` into a `ParsedFile` located at `src/lib.rs`.
    fn parse_file(source: &str) -> ParsedFile {
        let syntax = syn::parse_file(source).expect("source should parse");
        ParsedFile {
            path: PathBuf::from("src/lib.rs"),
            source: source.to_owned(),
            syntax,
        }
    }

    /// Runs SW005 over `source`, treated as the only file in the crate.
    fn run_rule(source: &str) -> Vec<RuleMatch> {
        let file = parse_file(source);
        let ctx = RuleContext { files: std::slice::from_ref(&file) };
        UncheckedArithmeticRule.match_file(&file, &ctx)
    }

    #[test]
    fn flags_compound_add_assign_on_account_field() {
        let findings = run_rule(r#"
            use anchor_lang::prelude::*;
            pub fn handler(ctx: Context<Deposit>, amount: u64) -> Result<()> {
                ctx.accounts.vault.balance += amount;
                Ok(())
            }
        "#);
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].rule_id, "SW005");
        assert!(findings[0].message.contains('+'));
    }

    #[test]
    fn flags_sub_assign_and_mul_on_account_field() {
        let findings = run_rule(r#"
            use anchor_lang::prelude::*;
            pub fn handler(ctx: Context<Transfer>, amount: u64, rate: u64) -> Result<()> {
                ctx.accounts.vault.balance -= amount;
                let fee = ctx.accounts.vault.balance * rate;
                Ok(())
            }
        "#);
        assert_eq!(findings.len(), 2);
        assert!(findings.iter().all(|f| f.rule_id == "SW005"));
    }

    #[test]
    fn does_not_flag_loop_counter_or_local_arithmetic() {
        let findings = run_rule(r#"
            use anchor_lang::prelude::*;
            pub fn handler(_ctx: Context<Example>, amount: u64, fee: u64) -> Result<()> {
                let mut i = 0u64;
                i += 1;
                let total = amount + fee;
                Ok(())
            }
        "#);
        assert!(findings.is_empty());
    }

    #[test]
    fn does_not_flag_checked_arithmetic() {
        let findings = run_rule(r#"
            use anchor_lang::prelude::*;
            pub fn handler(ctx: Context<Deposit>, amount: u64) -> Result<()> {
                ctx.accounts.vault.balance = ctx.accounts.vault.balance
                    .checked_add(amount)
                    .ok_or(ErrorCode::Overflow)?;
                Ok(())
            }
            #[error_code]
            pub enum ErrorCode { Overflow }
        "#);
        assert!(findings.is_empty());
    }

    #[test]
    fn does_not_flag_saturating_arithmetic() {
        let findings = run_rule(r#"
            use anchor_lang::prelude::*;
            pub fn handler(ctx: Context<Deposit>, amount: u64) -> Result<()> {
                ctx.accounts.vault.balance = ctx.accounts.vault.balance.saturating_add(amount);
                Ok(())
            }
        "#);
        assert!(findings.is_empty());
    }
}