Skip to main content

sentio_core/rules/anchor/
arbitrary_cpi.rs

1use crate::finding::SourceLocation;
2use crate::instruction_analysis::{collect_instruction_index, CallKind};
3use crate::rules::{Rule, RuleContext, RuleMatch, RuleMetadata, RuleSeverity};
4use crate::syntax::ParsedFile;
5
/// Rule SW003 ("Arbitrary CPI target").
///
/// Reports raw `invoke`-family CPI calls that are not preceded, within the same
/// function, by a guard that references an account key.
#[derive(Default, Debug)]
pub struct ArbitraryCpiRule;
8
9impl Rule for ArbitraryCpiRule {
10    fn metadata(&self) -> &RuleMetadata {
11        static METADATA: RuleMetadata = RuleMetadata {
12            id: "SW003",
13            title: "Arbitrary CPI target",
14            severity: RuleSeverity::Critical,
15            description: "Detects CPI calls where no key or program ID check precedes the invocation, allowing an attacker to supply a malicious program as the CPI target.",
16            fix_guidance: "Verify the target program key before invoking (e.g. require!(cpi_program.key() == expected::ID, ...)) or use Program<'info, T> so Anchor validates the program ID automatically.",
17        };
18        &METADATA
19    }
20
21    fn match_file(&self, file: &ParsedFile, _ctx: &RuleContext<'_>) -> Vec<RuleMatch> {
22        let index = collect_instruction_index(&file.syntax);
23        let mut findings = Vec::new();
24
25        for function in &index.functions {
26            // Only flag raw invoke/invoke_signed — Anchor CpiContext calls are validated
27            // at the account struct level via Program<'info, T> (covered by SW020).
28            let cpi_calls: Vec<_> = function
29                .calls
30                .iter()
31                .filter(|c| c.kind == CallKind::Cpi && is_raw_invoke(&c.callee))
32                .collect();
33
34            if cpi_calls.is_empty() {
35                continue;
36            }
37
38            for cpi_call in cpi_calls {
39                // Check if any key-referencing guard appears before this CPI call.
40                let guarded = function
41                    .guards
42                    .iter()
43                    .any(|g| g.references_key && g.order < cpi_call.order);
44
45                if !guarded {
46                    findings.push(RuleMatch {
47                        rule_id: "SW003",
48                        severity: RuleSeverity::Critical,
49                        message: format!(
50                            "CPI call `{}` in `{}` has no preceding program key validation.",
51                            cpi_call.callee, function.name
52                        ),
53                        location: SourceLocation {
54                            path: file.path.display().to_string(),
55                            line: cpi_call.span.start_line,
56                            column: cpi_call.span.start_column,
57                        },
58                        help: Some(
59                            "Add require!(program.key() == expected::ID, ...) before the CPI, or use Program<'info, T> to enforce program ID validation at the account level."
60                                .to_string(),
61                        ),
62                    });
63                }
64            }
65        }
66
67        findings
68    }
69}
70
/// Returns `true` when `callee` names one of the raw Solana CPI entry points:
/// `invoke`, `invoke_signed`, `invoke_unchecked` or `invoke_signed_unchecked`.
/// The name may appear bare or path-qualified (e.g. `solana_program::program::invoke`).
///
/// Only the final `::`-separated segment is compared, after trimming. As a
/// result, token-stringified paths such as `program :: invoke` are recognised,
/// while names that merely contain "invoke" (e.g. `my_invoke`) are not.
fn is_raw_invoke(callee: &str) -> bool {
    const RAW_INVOKE_FNS: [&str; 4] = [
        "invoke",
        "invoke_signed",
        "invoke_unchecked",
        "invoke_signed_unchecked",
    ];

    // `rsplit` always yields at least one item, so the fallback is never used.
    let last_segment = callee.rsplit("::").next().unwrap_or(callee).trim();
    RAW_INVOKE_FNS.contains(&last_segment)
}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rules::RuleContext;
    use crate::syntax::ParsedFile;
    use std::path::PathBuf;

    /// Builds a `ParsedFile` at `src/lib.rs` from inline source; panics if it does not parse.
    fn parse_file(source: &str) -> ParsedFile {
        ParsedFile {
            path: PathBuf::from("src/lib.rs"),
            source: source.to_string(),
            syntax: syn::parse_file(source).expect("source should parse"),
        }
    }

    /// Runs `ArbitraryCpiRule` over a single-file context built from `source`.
    fn run_rule(source: &str) -> Vec<RuleMatch> {
        let file = parse_file(source);
        let rule = ArbitraryCpiRule;
        rule.match_file(&file, &RuleContext { files: std::slice::from_ref(&file) })
    }

    #[test]
    fn flags_cpi_without_key_check() {
        let findings = run_rule(r#"
            use anchor_lang::prelude::*;
            use solana_program::program::invoke;

            pub fn handler(ctx: Context<Example>) -> Result<()> {
                invoke(
                    &instruction,
                    &[ctx.accounts.target_program.to_account_info()],
                )?;
                Ok(())
            }
        "#);

        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].rule_id, "SW003");
    }

    #[test]
    fn does_not_flag_cpi_with_key_check_before() {
        let findings = run_rule(r#"
            use anchor_lang::prelude::*;
            use solana_program::program::invoke;

            pub fn handler(ctx: Context<Example>) -> Result<()> {
                require!(
                    ctx.accounts.target_program.key() == &expected_program::ID,
                    ErrorCode::InvalidProgram
                );
                invoke(
                    &instruction,
                    &[ctx.accounts.target_program.to_account_info()],
                )?;
                Ok(())
            }
        "#);

        assert!(findings.is_empty());
    }

    #[test]
    fn flags_cpi_when_key_check_comes_after() {
        let findings = run_rule(r#"
            use anchor_lang::prelude::*;
            use solana_program::program::invoke;

            pub fn handler(ctx: Context<Example>) -> Result<()> {
                invoke(
                    &instruction,
                    &[ctx.accounts.target_program.to_account_info()],
                )?;
                require!(
                    ctx.accounts.target_program.key() == &expected_program::ID,
                    ErrorCode::InvalidProgram
                );
                Ok(())
            }
        "#);

        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].rule_id, "SW003");
    }

    #[test]
    fn flags_every_unguarded_cpi_call() {
        let findings = run_rule(r#"
            use anchor_lang::prelude::*;
            use solana_program::program::invoke;

            pub fn handler(ctx: Context<Example>) -> Result<()> {
                invoke(
                    &first_instruction,
                    &[ctx.accounts.first_program.to_account_info()],
                )?;
                invoke(
                    &second_instruction,
                    &[ctx.accounts.second_program.to_account_info()],
                )?;
                Ok(())
            }
        "#);

        assert_eq!(findings.len(), 2);
        assert!(findings.iter().all(|f| f.rule_id == "SW003"));
    }

    #[test]
    fn does_not_flag_anchor_cpi_context_calls() {
        let findings = run_rule(r#"
            use anchor_lang::prelude::*;
            use anchor_spl::token::{self, Transfer};

            pub fn handler(ctx: Context<Example>, amount: u64) -> Result<()> {
                let cpi_ctx = CpiContext::new(
                    ctx.accounts.token_program.to_account_info(),
                    Transfer {
                        from: ctx.accounts.from.to_account_info(),
                        to: ctx.accounts.to.to_account_info(),
                        authority: ctx.accounts.authority.to_account_info(),
                    },
                );
                token::transfer(cpi_ctx, amount)?;
                Ok(())
            }
        "#);

        assert!(findings.is_empty());
    }

    #[test]
    fn does_not_flag_function_with_no_cpi() {
        let findings = run_rule(r#"
            use anchor_lang::prelude::*;

            pub fn handler(ctx: Context<Example>) -> Result<()> {
                ctx.accounts.vault.balance = 100;
                Ok(())
            }
        "#);

        assert!(findings.is_empty());
    }

    #[test]
    fn is_raw_invoke_matches_only_invoke_family() {
        assert!(is_raw_invoke("invoke"));
        assert!(is_raw_invoke("invoke_signed"));
        assert!(is_raw_invoke("  invoke_unchecked  "));
        assert!(is_raw_invoke("solana_program::program::invoke_signed"));
        assert!(!is_raw_invoke("my_invoke"));
        assert!(!is_raw_invoke("invoke_helper"));
        assert!(!is_raw_invoke("token::transfer"));
    }
}