Skip to main content

sentio_core/
instruction_analysis.rs

1use crate::ast_index::{span_of, AstSpan};
2use quote::ToTokens;
3use serde::Serialize;
4use std::collections::HashMap;
5use syn::parse::Parser;
6use syn::spanned::Spanned;
7use syn::visit::{self, Visit};
8
/// Result of analysing one parsed source file: every function the collector found in it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct InstructionIndex {
    /// One entry per free function or `impl` method, in the order they were visited.
    pub functions: Vec<InstructionFunction>,
}
13
/// Evidence gathered from the body of a single function or method.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct InstructionFunction {
    /// The function's own identifier, without any module or type prefix.
    pub name: String,
    /// `name` prefixed with the enclosing module names and, when inside an `impl`, the
    /// token text of the innermost `impl`'s self type, joined with `::`.
    pub qualified_name: String,
    /// Location of the whole function item, as produced by `span_of`.
    pub span: AstSpan,
    /// Guard conditions (`if` conditions and `require*!` / `assert*!` macros) in the body.
    pub guards: Vec<GuardEvidence>,
    /// Function and method calls found in the body.
    pub calls: Vec<CallEvidence>,
    /// Assignment and compound-assignment targets found in the body.
    pub writes: Vec<WriteEvidence>,
}
23
/// A condition checked inside a function body, together with what it mentions.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct GuardEvidence {
    /// Which syntactic construct produced this guard.
    pub kind: GuardKind,
    /// Token text of the checked condition. For `*_eq` / `*_ne` macros this is rendered as
    /// `lhs == rhs` / `lhs != rhs`; for other macros it is the first argument.
    pub expression: String,
    /// Location of the condition (or of the macro invocation), as produced by `span_of`.
    pub span: AstSpan,
    /// 1-based position in the per-function sequence shared by guards, calls and writes.
    pub order: usize,
    /// `true` when the condition contains an `owner` field, method or path segment.
    pub references_owner: bool,
    /// `true` when the condition contains an `is_signer` or `signer` field, method or
    /// path segment.
    pub references_signer: bool,
    /// `true` when the condition contains a `key` field, method or path segment.
    pub references_key: bool,
}
34
/// Syntactic origin of a [`GuardEvidence`]. Serialized in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum GuardKind {
    /// The condition of an `if` expression.
    IfCondition,
    /// A macro whose name starts with `require`.
    RequireMacro,
    /// A macro whose name starts with `assert`.
    AssertMacro,
}
42
/// A function or method call found inside a function body.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct CallEvidence {
    /// Classification derived from `callee` by `classify_call_kind`.
    pub kind: CallKind,
    /// Whitespace-stripped token text of the called function, or `receiver.method` for
    /// method calls.
    pub callee: String,
    /// Location of the call expression, as produced by `span_of`.
    pub span: AstSpan,
    /// 1-based position in the per-function sequence shared by guards, calls and writes.
    pub order: usize,
    /// Account names extracted from the CpiContext struct for this CPI call.
    /// Empty when the CPI accounts could not be resolved (raw invoke, unknown binding).
    pub cpi_account_names: Vec<String>,
}
53
/// Coarse classification of a call, decided purely from the callee's text.
/// Serialized in `snake_case`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum CallKind {
    /// Callee contains `try_deserialize` (case-insensitive) or ends with `::try_from`,
    /// `::from_account_info`, `.load` or `.load_mut`.
    Deserialization,
    /// Callee is `invoke` / `invoke_signed` (bare or path-qualified), contains
    /// `CpiContext::new` or `anchor_spl::token::`, or starts with `token::`.
    Cpi,
    /// Callee ends with `.reload` or `::reload`. Checked before the other kinds.
    Reload,
    /// Anything that matched none of the rules above.
    Other,
}
62
/// An assignment (`=`) or compound assignment (`+=`, `-=`, ...) found in a function body.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct WriteEvidence {
    /// Whitespace-stripped token text of the left-hand side.
    pub target: String,
    /// Location of the whole assignment expression, as produced by `span_of`.
    pub span: AstSpan,
    /// 1-based position in the per-function sequence shared by guards, calls and writes.
    pub order: usize,
}
69
70pub fn collect_instruction_index(file: &syn::File) -> InstructionIndex {
71    let mut collector = InstructionCollector::default();
72    collector.visit_file(file);
73    InstructionIndex {
74        functions: collector.functions,
75    }
76}
77
/// File-level visitor that records one [`InstructionFunction`] per function it encounters.
#[derive(Default)]
struct InstructionCollector {
    /// Functions collected so far, in visit order.
    functions: Vec<InstructionFunction>,
    /// Names of the `mod` items currently being visited, outermost first.
    module_stack: Vec<String>,
    /// Token text of the self types of the `impl` blocks currently being visited; only the
    /// innermost (last) entry is used when building qualified names.
    impl_stack: Vec<String>,
}
84
85impl<'ast> Visit<'ast> for InstructionCollector {
86    fn visit_item_mod(&mut self, node: &'ast syn::ItemMod) {
87        self.module_stack.push(node.ident.to_string());
88
89        if let Some((_, items)) = &node.content {
90            for item in items {
91                self.visit_item(item);
92            }
93        }
94
95        self.module_stack.pop();
96    }
97
98    fn visit_item_impl(&mut self, node: &'ast syn::ItemImpl) {
99        self.impl_stack.push(node.self_ty.to_token_stream().to_string());
100        visit::visit_item_impl(self, node);
101        self.impl_stack.pop();
102    }
103
104    fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
105        self.collect_function(node.sig.ident.to_string(), node.span(), &node.block);
106    }
107
108    fn visit_impl_item_fn(&mut self, node: &'ast syn::ImplItemFn) {
109        self.collect_function(node.sig.ident.to_string(), node.span(), &node.block);
110    }
111}
112
113impl InstructionCollector {
114    fn collect_function(&mut self, name: String, span: proc_macro2::Span, block: &syn::Block) {
115        let mut collector = FunctionBodyCollector::default();
116        collector.visit_block(block);
117
118        self.functions.push(InstructionFunction {
119            qualified_name: self.qualified_name(&name),
120            name,
121            span: span_of(span),
122            guards: collector.guards,
123            calls: collector.calls,
124            writes: collector.writes,
125        });
126    }
127
128    fn qualified_name(&self, name: &str) -> String {
129        let mut parts = self.module_stack.clone();
130        if let Some(impl_name) = self.impl_stack.last() {
131            parts.push(impl_name.clone());
132        }
133        parts.push(name.to_string());
134        parts.join("::")
135    }
136}
137
/// Visitor that walks one function body and records its guards, calls and writes.
#[derive(Default)]
struct FunctionBodyCollector {
    /// Last sequence number handed out by `order()`; `0` until the first record is made.
    next_order: usize,
    /// Guards recorded so far.
    guards: Vec<GuardEvidence>,
    /// Calls recorded so far.
    calls: Vec<CallEvidence>,
    /// Writes recorded so far.
    writes: Vec<WriteEvidence>,
    /// Maps local variable names to the account names found in the struct literal they were
    /// bound to (e.g. `let accounts = Transfer { from: ctx.accounts.vault, ... }` →
    /// `"accounts" → ["vault", ...]`). Used to resolve CpiContext args by name.
    let_bindings: HashMap<String, Vec<String>>,
}
149
150impl FunctionBodyCollector {
151    fn order(&mut self) -> usize {
152        self.next_order += 1;
153        self.next_order
154    }
155
156    fn push_guard(&mut self, kind: GuardKind, expression: syn::Expr) {
157        let features = ExprFeatures::from_expr(&expression);
158        self.push_guard_text(kind, expression.to_token_stream().to_string(), expression.span(), features);
159    }
160
161    fn push_guard_text(
162        &mut self,
163        kind: GuardKind,
164        expression: String,
165        span: proc_macro2::Span,
166        features: ExprFeatures,
167    ) {
168        let order = self.order();
169        self.guards.push(GuardEvidence {
170            kind,
171            expression,
172            span: span_of(span),
173            order,
174            references_owner: features.references_owner,
175            references_signer: features.references_signer,
176            references_key: features.references_key,
177        });
178    }
179
180    fn push_call(&mut self, callee: String, span: proc_macro2::Span, cpi_account_names: Vec<String>) {
181        let order = self.order();
182        self.calls.push(CallEvidence {
183            kind: classify_call_kind(&callee),
184            callee,
185            span: span_of(span),
186            order,
187            cpi_account_names,
188        });
189    }
190
191    /// Walk `expr` and return the account names it refers to, following variable
192    /// bindings recorded in `self.let_bindings`. Handles:
193    /// - struct literals: `Transfer { from: ctx.accounts.X, ... }` → `["X", ...]`
194    /// - `CpiContext::new(prog, accounts_expr)` → recurse on accounts_expr
195    /// - variable paths: look up in `let_bindings`
196    /// - references `&expr`: strip and recurse
197    fn extract_account_names_from_expr(&self, expr: &syn::Expr) -> Vec<String> {
198        match expr {
199            syn::Expr::Struct(s) => s
200                .fields
201                .iter()
202                .filter_map(|f| {
203                    let val = normalize_tokens(&f.expr.to_token_stream().to_string());
204                    extract_account_name_from_str(&val)
205                })
206                .collect(),
207            syn::Expr::Call(call) => {
208                let func = normalize_tokens(&call.func.to_token_stream().to_string());
209                if func.contains("CpiContext::new") {
210                    if let Some(accounts_arg) = call.args.iter().nth(1) {
211                        return self.extract_account_names_from_expr(accounts_arg);
212                    }
213                }
214                vec![]
215            }
216            syn::Expr::Path(p) => {
217                let var = p
218                    .path
219                    .segments
220                    .last()
221                    .map(|s| s.ident.to_string())
222                    .unwrap_or_default();
223                self.let_bindings.get(&var).cloned().unwrap_or_default()
224            }
225            syn::Expr::Reference(r) => self.extract_account_names_from_expr(&r.expr),
226            _ => vec![],
227        }
228    }
229
230    fn push_write(&mut self, target: String, span: proc_macro2::Span) {
231        let order = self.order();
232        self.writes.push(WriteEvidence {
233            target,
234            span: span_of(span),
235            order,
236        });
237    }
238
239    fn record_guard_macro(
240        &mut self,
241        path: &syn::Path,
242        tokens: &proc_macro2::TokenStream,
243        span: proc_macro2::Span,
244    ) {
245        if let Some(kind) = classify_guard_macro(path) {
246            if let Some((expression, features)) = macro_guard_payload(path, tokens) {
247                self.push_guard_text(kind, expression, span, features);
248            }
249        }
250    }
251}
252
253impl<'ast> Visit<'ast> for FunctionBodyCollector {
254    fn visit_stmt(&mut self, node: &'ast syn::Stmt) {
255        if let syn::Stmt::Macro(stmt) = node {
256            self.record_guard_macro(&stmt.mac.path, &stmt.mac.tokens, stmt.mac.span());
257        }
258
259        visit::visit_stmt(self, node);
260    }
261
262    fn visit_expr_if(&mut self, node: &'ast syn::ExprIf) {
263        self.push_guard(GuardKind::IfCondition, (*node.cond).clone());
264        visit::visit_expr_if(self, node);
265    }
266
267    fn visit_expr_macro(&mut self, node: &'ast syn::ExprMacro) {
268        self.record_guard_macro(&node.mac.path, &node.mac.tokens, node.span());
269
270        visit::visit_expr_macro(self, node);
271    }
272
273    fn visit_local(&mut self, node: &'ast syn::Local) {
274        if let (Some(init), Some(var_name)) = (&node.init, get_simple_pat_ident(&node.pat)) {
275            let names = self.extract_account_names_from_expr(&init.expr);
276            if !names.is_empty() {
277                self.let_bindings.insert(var_name, names);
278            }
279        }
280        visit::visit_local(self, node);
281    }
282
283    fn visit_expr_call(&mut self, node: &'ast syn::ExprCall) {
284        let callee = normalize_tokens(&node.func.to_token_stream().to_string());
285        let cpi_account_names = if classify_call_kind(&callee) == CallKind::Cpi {
286            let mut found = vec![];
287            for arg in &node.args {
288                let names = self.extract_account_names_from_expr(arg);
289                if !names.is_empty() {
290                    found = names;
291                    break;
292                }
293            }
294            found
295        } else {
296            vec![]
297        };
298        self.push_call(callee, node.span(), cpi_account_names);
299        visit::visit_expr_call(self, node);
300    }
301
302    fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
303        let receiver = normalize_tokens(&node.receiver.to_token_stream().to_string());
304        let callee = format!("{receiver}.{}", node.method);
305        self.push_call(callee, node.span(), vec![]);
306        visit::visit_expr_method_call(self, node);
307    }
308
309    fn visit_expr_assign(&mut self, node: &'ast syn::ExprAssign) {
310        self.push_write(normalize_tokens(&node.left.to_token_stream().to_string()), node.span());
311        visit::visit_expr_assign(self, node);
312    }
313
314    fn visit_expr_binary(&mut self, node: &'ast syn::ExprBinary) {
315        if is_assign_op(&node.op) {
316            self.push_write(normalize_tokens(&node.left.to_token_stream().to_string()), node.span());
317        }
318
319        visit::visit_expr_binary(self, node);
320    }
321}
322
/// Flags describing which security-relevant identifiers an expression mentions.
#[derive(Default)]
struct ExprFeatures {
    /// The expression contains an `owner` field, method or path segment.
    references_owner: bool,
    /// The expression contains an `is_signer` or `signer` field, method or path segment.
    references_signer: bool,
    /// The expression contains a `key` field, method or path segment.
    references_key: bool,
}
329
330impl ExprFeatures {
331    fn from_expr(expr: &syn::Expr) -> Self {
332        let mut collector = ExprFeatureCollector::default();
333        collector.visit_expr(expr);
334        collector.features
335    }
336
337    fn merge(self, other: Self) -> Self {
338        Self {
339            references_owner: self.references_owner || other.references_owner,
340            references_signer: self.references_signer || other.references_signer,
341            references_key: self.references_key || other.references_key,
342        }
343    }
344}
345
/// Visitor that sets [`ExprFeatures`] flags while walking an expression.
#[derive(Default)]
struct ExprFeatureCollector {
    /// Flags accumulated so far; a flag is only ever switched on, never off.
    features: ExprFeatures,
}
350
351impl<'ast> Visit<'ast> for ExprFeatureCollector {
352    fn visit_expr_field(&mut self, node: &'ast syn::ExprField) {
353        if let syn::Member::Named(member) = &node.member {
354            self.record_ident(member);
355        }
356        visit::visit_expr_field(self, node);
357    }
358
359    fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
360        self.record_ident(&node.method);
361        visit::visit_expr_method_call(self, node);
362    }
363
364    fn visit_path(&mut self, node: &'ast syn::Path) {
365        for segment in &node.segments {
366            self.record_ident(&segment.ident);
367        }
368        visit::visit_path(self, node);
369    }
370}
371
372impl ExprFeatureCollector {
373    fn record_ident(&mut self, ident: &syn::Ident) {
374        match ident.to_string().as_str() {
375            "owner" => self.features.references_owner = true,
376            "is_signer" | "signer" => self.features.references_signer = true,
377            "key" => self.features.references_key = true,
378            _ => {}
379        }
380    }
381}
382
383fn classify_guard_macro(path: &syn::Path) -> Option<GuardKind> {
384    let ident = path.segments.last()?.ident.to_string();
385    if ident.starts_with("require") {
386        Some(GuardKind::RequireMacro)
387    } else if ident.starts_with("assert") {
388        Some(GuardKind::AssertMacro)
389    } else {
390        None
391    }
392}
393
394fn parse_macro_guard_args(tokens: &proc_macro2::TokenStream) -> Option<Vec<syn::Expr>> {
395    let parser = syn::punctuated::Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated;
396    let args = parser.parse2(tokens.clone()).ok()?;
397    Some(args.into_iter().collect())
398}
399
400fn macro_guard_payload(path: &syn::Path, tokens: &proc_macro2::TokenStream) -> Option<(String, ExprFeatures)> {
401    let args = parse_macro_guard_args(tokens)?;
402    if args.is_empty() {
403        return None;
404    }
405
406    let ident = path.segments.last()?.ident.to_string();
407    if (ident.ends_with("_eq") || ident == "assert_eq") && args.len() >= 2 {
408        let expression = format!(
409            "{} == {}",
410            args[0].to_token_stream(),
411            args[1].to_token_stream()
412        );
413        let features = ExprFeatures::from_expr(&args[0]).merge(ExprFeatures::from_expr(&args[1]));
414        return Some((expression, features));
415    }
416
417    if (ident.ends_with("_ne") || ident == "assert_ne") && args.len() >= 2 {
418        let expression = format!(
419            "{} != {}",
420            args[0].to_token_stream(),
421            args[1].to_token_stream()
422        );
423        let features = ExprFeatures::from_expr(&args[0]).merge(ExprFeatures::from_expr(&args[1]));
424        return Some((expression, features));
425    }
426
427    let first = args.into_iter().next()?;
428    let features = ExprFeatures::from_expr(&first);
429    Some((first.to_token_stream().to_string(), features))
430}
431
432fn classify_call_kind(callee: &str) -> CallKind {
433    let normalized = normalize_tokens(callee);
434    let lower = normalized.to_lowercase();
435
436    if normalized.ends_with(".reload") || normalized.ends_with("::reload") {
437        return CallKind::Reload;
438    }
439
440    if lower.contains("try_deserialize")
441        || normalized.ends_with("::try_from")
442        || normalized.ends_with("::from_account_info")
443        || normalized.ends_with(".load")
444        || normalized.ends_with(".load_mut")
445    {
446        return CallKind::Deserialization;
447    }
448
449    if normalized == "invoke"
450        || normalized == "invoke_signed"
451        || normalized.ends_with("::invoke")
452        || normalized.ends_with("::invoke_signed")
453        || normalized.contains("CpiContext::new")
454        || normalized.contains("CpiContext::new_with_signer")
455        || normalized.starts_with("token::")
456        || normalized.contains("anchor_spl::token::")
457    {
458        return CallKind::Cpi;
459    }
460
461    CallKind::Other
462}
463
464fn is_assign_op(op: &syn::BinOp) -> bool {
465    matches!(
466        op,
467        syn::BinOp::AddAssign(_)
468            | syn::BinOp::SubAssign(_)
469            | syn::BinOp::MulAssign(_)
470            | syn::BinOp::DivAssign(_)
471            | syn::BinOp::RemAssign(_)
472            | syn::BinOp::BitXorAssign(_)
473            | syn::BinOp::BitAndAssign(_)
474            | syn::BinOp::BitOrAssign(_)
475            | syn::BinOp::ShlAssign(_)
476            | syn::BinOp::ShrAssign(_)
477    )
478}
479
/// Returns `tokens` with every whitespace character removed, so that token-stream text
/// such as `ctx . accounts . vault` compares equal to `ctx.accounts.vault`.
fn normalize_tokens(tokens: &str) -> String {
    tokens.chars().filter(|ch| !ch.is_whitespace()).collect()
}
483
484fn get_simple_pat_ident(pat: &syn::Pat) -> Option<String> {
485    if let syn::Pat::Ident(p) = pat {
486        Some(p.ident.to_string())
487    } else {
488        None
489    }
490}
491
/// Returns the identifier that directly follows the first `.accounts.` in `s`.
///
/// The identifier is the longest run of alphanumeric characters and underscores after the
/// marker. Returns `None` when `s` has no `.accounts.` or nothing identifier-like follows
/// its first occurrence.
fn extract_account_name_from_str(s: &str) -> Option<String> {
    const MARKER: &str = ".accounts.";

    let (_, tail) = s.split_once(MARKER)?;
    let end = tail
        .find(|c: char| !(c.is_alphanumeric() || c == '_'))
        .unwrap_or(tail.len());
    let name = &tail[..end];

    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as a Rust file and runs the collector over it.
    fn index_source(source: &str) -> InstructionIndex {
        let file = syn::parse_file(source).expect("source should parse");
        collect_instruction_index(&file)
    }

    /// `true` when `function` has a call of `kind` whose callee contains `needle`.
    fn has_call(function: &InstructionFunction, kind: CallKind, needle: &str) -> bool {
        function
            .calls
            .iter()
            .any(|call| call.kind == kind && call.callee.contains(needle))
    }

    /// Order of the first recorded call of `kind`, if any.
    fn first_call_order(function: &InstructionFunction, kind: CallKind) -> Option<usize> {
        function
            .calls
            .iter()
            .find(|call| call.kind == kind)
            .map(|call| call.order)
    }

    /// First recorded write whose target is exactly `target`, if any.
    fn write_to<'a>(function: &'a InstructionFunction, target: &str) -> Option<&'a WriteEvidence> {
        function.writes.iter().find(|write| write.target == target)
    }

    #[test]
    fn collects_functions_from_modules_and_impls() {
        let index = index_source(
            r#"
            mod instructions {
                pub fn process() {}
            }

            impl Processor {
                pub fn handle() {}
            }
            "#,
        );

        let qualified_names: Vec<&str> = index
            .functions
            .iter()
            .map(|function| function.qualified_name.as_str())
            .collect();
        assert_eq!(qualified_names, ["instructions::process", "Processor::handle"]);
    }

    #[test]
    fn models_guards_calls_and_writes_in_order() {
        let index = index_source(
            r#"
            pub fn process(mut state: Account<'info, Vault>, authority: Signer<'info>) -> Result<()> {
                if state.owner != authority.key() {
                    return Err(ErrorCode::Unauthorized.into());
                }

                require!(authority.is_signer, ErrorCode::Unauthorized);

                let account = Vault::try_deserialize(&mut data)?;
                invoke_signed(&ix, &accounts, signer_seeds)?;
                state.reload()?;
                state.counter += account.amount;
                state.authority = authority.key();
                Ok(())
            }
            "#,
        );
        assert_eq!(index.functions.len(), 1);
        let function = &index.functions[0];
        assert_eq!(function.qualified_name, "process");

        // Guards: the `if` condition and the `require!`, covering all three flags between them.
        let guards = &function.guards;
        assert_eq!(guards.len(), 2);
        assert!(guards.iter().any(|guard| guard.references_owner));
        assert!(guards.iter().any(|guard| guard.references_signer));
        assert!(guards.iter().any(|guard| guard.references_key));

        // Calls: one of each interesting kind.
        assert!(has_call(function, CallKind::Deserialization, "Vault::try_deserialize"));
        assert!(has_call(function, CallKind::Cpi, "invoke_signed"));
        assert!(has_call(function, CallKind::Reload, "state.reload"));

        // Writes: both the compound and the plain assignment.
        assert!(write_to(function, "state.counter").is_some());
        assert!(write_to(function, "state.authority").is_some());

        // Ordering: CPI, then reload, then the counter write.
        let cpi_order =
            first_call_order(function, CallKind::Cpi).expect("cpi call should be recorded");
        let reload_order =
            first_call_order(function, CallKind::Reload).expect("reload call should be recorded");
        let write_order = write_to(function, "state.counter")
            .expect("counter write should be recorded")
            .order;

        assert!(cpi_order < reload_order);
        assert!(reload_order < write_order);
    }

    #[test]
    fn models_eq_style_guard_macros_with_both_operands() {
        let index = index_source(
            r#"
            pub fn process(account: AccountInfo<'info>, authority: Signer<'info>) -> Result<()> {
                require_keys_eq!(account.owner, authority.key(), ErrorCode::Unauthorized);
                Ok(())
            }
            "#,
        );

        let guard = index.functions[0]
            .guards
            .iter()
            .find(|guard| guard.kind == GuardKind::RequireMacro)
            .expect("require macro guard should be recorded");

        assert!(guard.expression.contains("=="));
        assert!(guard.references_owner);
        assert!(guard.references_key);
    }
}