sentio_core/rules/anchor/
arbitrary_cpi.rs1use crate::finding::SourceLocation;
2use crate::instruction_analysis::{collect_instruction_index, CallKind};
3use crate::rules::{Rule, RuleContext, RuleMatch, RuleMetadata, RuleSeverity};
4use crate::syntax::ParsedFile;
5
/// Rule SW003 ("Arbitrary CPI target").
///
/// Stateless marker type: all of the rule's behaviour lives in its `Rule`
/// implementation, which flags raw `invoke`-family CPI calls that have no
/// key-referencing guard earlier in the same function.
#[derive(Default, Debug)]
pub struct ArbitraryCpiRule;

9impl Rule for ArbitraryCpiRule {
10 fn metadata(&self) -> &RuleMetadata {
11 static METADATA: RuleMetadata = RuleMetadata {
12 id: "SW003",
13 title: "Arbitrary CPI target",
14 severity: RuleSeverity::Critical,
15 description: "Detects CPI calls where no key or program ID check precedes the invocation, allowing an attacker to supply a malicious program as the CPI target.",
16 fix_guidance: "Verify the target program key before invoking (e.g. require!(cpi_program.key() == expected::ID, ...)) or use Program<'info, T> so Anchor validates the program ID automatically.",
17 };
18 &METADATA
19 }
20
21 fn match_file(&self, file: &ParsedFile, _ctx: &RuleContext<'_>) -> Vec<RuleMatch> {
22 let index = collect_instruction_index(&file.syntax);
23 let mut findings = Vec::new();
24
25 for function in &index.functions {
26 let cpi_calls: Vec<_> = function
29 .calls
30 .iter()
31 .filter(|c| c.kind == CallKind::Cpi && is_raw_invoke(&c.callee))
32 .collect();
33
34 if cpi_calls.is_empty() {
35 continue;
36 }
37
38 for cpi_call in cpi_calls {
39 let guarded = function
41 .guards
42 .iter()
43 .any(|g| g.references_key && g.order < cpi_call.order);
44
45 if !guarded {
46 findings.push(RuleMatch {
47 rule_id: "SW003",
48 severity: RuleSeverity::Critical,
49 message: format!(
50 "CPI call `{}` in `{}` has no preceding program key validation.",
51 cpi_call.callee, function.name
52 ),
53 location: SourceLocation {
54 path: file.path.display().to_string(),
55 line: cpi_call.span.start_line,
56 column: cpi_call.span.start_column,
57 },
58 help: Some(
59 "Add require!(program.key() == expected::ID, ...) before the CPI, or use Program<'info, T> to enforce program ID validation at the account level."
60 .to_string(),
61 ),
62 });
63 }
64 }
65 }
66
67 findings
68 }
69}
70
/// Returns `true` when `callee` names a raw Solana CPI entry point.
///
/// Recognised names are `invoke`, `invoke_signed`, `invoke_unchecked` and
/// `invoke_signed_unchecked`, either bare or as the final segment of a
/// `::`-separated path (e.g. `solana_program::program::invoke_signed`).
/// Surrounding whitespace, including whitespace around the final `::`
/// (as in token-stream renderings like `solana_program :: program :: invoke`),
/// is ignored.
fn is_raw_invoke(callee: &str) -> bool {
    // `rsplit` always yields at least one item: the last path segment, or the
    // whole string when there is no `::`.
    let last_segment = callee.rsplit("::").next().unwrap_or(callee).trim();
    matches!(
        last_segment,
        "invoke" | "invoke_signed" | "invoke_unchecked" | "invoke_signed_unchecked"
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::rules::RuleContext;
    use crate::syntax::ParsedFile;
    use std::path::PathBuf;

    /// Parses `source` with `syn` and wraps it as a `ParsedFile` at `src/lib.rs`.
    fn parse_file(source: &str) -> ParsedFile {
        let syntax = syn::parse_file(source).expect("source should parse");
        ParsedFile {
            path: PathBuf::from("src/lib.rs"),
            source: source.to_string(),
            syntax,
        }
    }

    /// Runs `ArbitraryCpiRule` against a one-file workspace built from `source`.
    fn run_rule(source: &str) -> Vec<RuleMatch> {
        let file = parse_file(source);
        let ctx = RuleContext {
            files: std::slice::from_ref(&file),
        };
        ArbitraryCpiRule.match_file(&file, &ctx)
    }

    #[test]
    fn flags_cpi_without_key_check() {
        let findings = run_rule(
            r#"
            use anchor_lang::prelude::*;
            use solana_program::program::invoke;

            pub fn handler(ctx: Context<Example>) -> Result<()> {
                invoke(
                    &instruction,
                    &[ctx.accounts.target_program.to_account_info()],
                )?;
                Ok(())
            }
            "#,
        );

        assert_eq!(findings.len(), 1);
        assert_eq!(findings[0].rule_id, "SW003");
    }

    #[test]
    fn does_not_flag_cpi_with_key_check_before() {
        let findings = run_rule(
            r#"
            use anchor_lang::prelude::*;
            use solana_program::program::invoke;

            pub fn handler(ctx: Context<Example>) -> Result<()> {
                require!(
                    ctx.accounts.target_program.key() == &expected_program::ID,
                    ErrorCode::InvalidProgram
                );
                invoke(
                    &instruction,
                    &[ctx.accounts.target_program.to_account_info()],
                )?;
                Ok(())
            }
            "#,
        );

        assert!(findings.is_empty());
    }

    #[test]
    fn does_not_flag_function_with_no_cpi() {
        let findings = run_rule(
            r#"
            use anchor_lang::prelude::*;

            pub fn handler(ctx: Context<Example>) -> Result<()> {
                ctx.accounts.vault.balance = 100;
                Ok(())
            }
            "#,
        );

        assert!(findings.is_empty());
    }
}