Skip to main content

sentio_core/rules/anchor/arbitrary_cpi.rs

1use crate::finding::SourceLocation;
2use crate::instruction_analysis::{collect_instruction_index, CallKind};
3use crate::rules::{Rule, RuleContext, RuleMatch, RuleMetadata, RuleSeverity};
4use crate::syntax::ParsedFile;
5
/// Rule `SW003`: reports raw Solana CPI calls (`invoke` and its variants) that
/// are not preceded by a key check within the same function.
///
/// This is a stateless unit struct; the rule's identity and wording live in its
/// `Rule` implementation.
#[derive(Default, Debug)]
pub struct ArbitraryCpiRule;
8
9impl Rule for ArbitraryCpiRule {
10    fn metadata(&self) -> &RuleMetadata {
11        static METADATA: RuleMetadata = RuleMetadata {
12            id: "SW003",
13            title: "Arbitrary CPI target",
14            severity: RuleSeverity::Critical,
15            description: "Detects CPI calls where no key or program ID check precedes the invocation, allowing an attacker to supply a malicious program as the CPI target.",
16            fix_guidance: "Verify the target program key before invoking (e.g. require!(cpi_program.key() == expected::ID, ...)) or use Program<'info, T> so Anchor validates the program ID automatically.",
17        };
18        &METADATA
19    }
20
21    fn match_file(&self, file: &ParsedFile, _ctx: &RuleContext<'_>) -> Vec<RuleMatch> {
22        let index = collect_instruction_index(&file.syntax);
23        let mut findings = Vec::new();
24
25        for function in &index.functions {
26            // Only flag raw invoke/invoke_signed — Anchor CpiContext calls are validated
27            // at the account struct level via Program<'info, T> (covered by SW020).
28            let cpi_calls: Vec<_> = function
29                .calls
30                .iter()
31                .filter(|c| c.kind == CallKind::Cpi && is_raw_invoke(&c.callee))
32                .collect();
33
34            if cpi_calls.is_empty() {
35                continue;
36            }
37
38            for cpi_call in cpi_calls {
39                // Check if any key-referencing guard appears before this CPI call.
40                let guarded = function
41                    .guards
42                    .iter()
43                    .any(|g| g.references_key && g.order < cpi_call.order);
44
45                if !guarded {
46                    findings.push(RuleMatch {
47                        rule_id: "SW003",
48                        severity: RuleSeverity::Critical,
49                        message: format!(
50                            "CPI call `{}` in `{}` has no preceding program key validation.",
51                            cpi_call.callee, function.name
52                        ),
53                        location: SourceLocation {
54                            path: file.path.display().to_string(),
55                            line: cpi_call.span.start_line,
56                            column: cpi_call.span.start_column,
57                        },
58                        help: Some(
59                            "Add require!(program.key() == expected::ID, ...) before the CPI, or use Program<'info, T> to enforce program ID validation at the account level."
60                                .to_string(),
61                        ),
62                    });
63                }
64            }
65        }
66
67        findings
68    }
69}
70
/// Returns `true` when `callee` names one of the raw Solana CPI entry points
/// from `solana_program::program`: `invoke`, `invoke_signed`,
/// `invoke_unchecked` or `invoke_signed_unchecked`.
///
/// The name may be bare (`invoke`) or path-qualified
/// (`solana_program::program::invoke_signed`). Only the last `::`-separated
/// segment is compared, after trimming whitespace. This means token-stream
/// renderings with spaces around the separator
/// (`solana_program :: program :: invoke`) are also recognised.
fn is_raw_invoke(callee: &str) -> bool {
    const RAW_INVOKE_FNS: &[&str] = &[
        "invoke",
        "invoke_signed",
        "invoke_unchecked",
        "invoke_signed_unchecked",
    ];

    // `rsplit` always yields at least one item, so `unwrap_or` never falls back
    // in practice; it just avoids a panic path.
    let last_segment = callee.rsplit("::").next().unwrap_or(callee).trim();
    RAW_INVOKE_FNS.contains(&last_segment)
}
80
#[cfg(test)]
mod tests {
    use super::*;
    use crate::rules::RuleContext;
    use crate::syntax::ParsedFile;
    use std::path::PathBuf;

    /// Builds a `ParsedFile` at `src/lib.rs` from an inline Rust snippet.
    fn parse_file(source: &str) -> ParsedFile {
        ParsedFile {
            path: PathBuf::from("src/lib.rs"),
            source: source.to_string(),
            syntax: syn::parse_file(source).expect("source should parse"),
        }
    }

    /// Runs `ArbitraryCpiRule` over a single-file workspace containing `file`.
    fn run_rule(file: &ParsedFile) -> Vec<RuleMatch> {
        ArbitraryCpiRule.match_file(
            file,
            &RuleContext {
                files: std::slice::from_ref(file),
            },
        )
    }

    #[test]
    fn flags_cpi_without_key_check() {
        let file = parse_file(
            r#"
            use anchor_lang::prelude::*;
            use solana_program::program::invoke;

            pub fn handler(ctx: Context<Example>) -> Result<()> {
                invoke(
                    &instruction,
                    &[ctx.accounts.target_program.to_account_info()],
                )?;
                Ok(())
            }
        "#,
        );

        let findings = run_rule(&file);
        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].rule_id, "SW003");
        // The message should name both the offending call and its function.
        assert!(findings[0].message.contains("invoke"));
        assert!(findings[0].message.contains("handler"));
    }

    #[test]
    fn does_not_flag_cpi_with_key_check_before() {
        let file = parse_file(
            r#"
            use anchor_lang::prelude::*;
            use solana_program::program::invoke;

            pub fn handler(ctx: Context<Example>) -> Result<()> {
                require!(
                    ctx.accounts.target_program.key() == &expected_program::ID,
                    ErrorCode::InvalidProgram
                );
                invoke(
                    &instruction,
                    &[ctx.accounts.target_program.to_account_info()],
                )?;
                Ok(())
            }
        "#,
        );

        assert!(run_rule(&file).is_empty());
    }

    #[test]
    fn flags_cpi_when_key_check_only_follows_it() {
        // A check that runs after the CPI cannot protect it.
        let file = parse_file(
            r#"
            use anchor_lang::prelude::*;
            use solana_program::program::invoke;

            pub fn handler(ctx: Context<Example>) -> Result<()> {
                invoke(
                    &instruction,
                    &[ctx.accounts.target_program.to_account_info()],
                )?;
                require!(
                    ctx.accounts.target_program.key() == &expected_program::ID,
                    ErrorCode::InvalidProgram
                );
                Ok(())
            }
        "#,
        );

        assert_eq!(run_rule(&file).len(), 1);
    }

    #[test]
    fn does_not_flag_function_with_no_cpi() {
        let file = parse_file(
            r#"
            use anchor_lang::prelude::*;

            pub fn handler(ctx: Context<Example>) -> Result<()> {
                ctx.accounts.vault.balance = 100;
                Ok(())
            }
        "#,
        );

        assert!(run_rule(&file).is_empty());
    }

    #[test]
    fn recognises_raw_invoke_callees() {
        assert!(is_raw_invoke("invoke"));
        assert!(is_raw_invoke("invoke_signed"));
        assert!(is_raw_invoke("invoke_unchecked"));
        assert!(is_raw_invoke("solana_program::program::invoke"));
        assert!(is_raw_invoke("solana_program::program::invoke_signed"));
        assert!(is_raw_invoke(" invoke "));
    }

    #[test]
    fn ignores_non_invoke_callees() {
        assert!(!is_raw_invoke(""));
        assert!(!is_raw_invoke("transfer"));
        assert!(!is_raw_invoke("my_invoke"));
        assert!(!is_raw_invoke("invoke::helper"));
    }
}