Skip to main content

sentio_core/
instruction_analysis.rs

1use crate::ast_index::{span_of, AstSpan};
2use quote::ToTokens;
3use serde::Serialize;
4use std::collections::HashMap;
5use syn::parse::Parser;
6use syn::spanned::Spanned;
7use syn::visit::{self, Visit};
8
9#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
10pub struct InstructionIndex {
11    pub functions: Vec<InstructionFunction>,
12}
13
14#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
15pub struct InstructionFunction {
16    pub name: String,
17    pub qualified_name: String,
18    pub span: AstSpan,
19    pub guards: Vec<GuardEvidence>,
20    pub calls: Vec<CallEvidence>,
21    pub writes: Vec<WriteEvidence>,
22}
23
24#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
25pub struct GuardEvidence {
26    pub kind: GuardKind,
27    pub expression: String,
28    pub span: AstSpan,
29    pub order: usize,
30    pub references_owner: bool,
31    pub references_signer: bool,
32    pub references_key: bool,
33}
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
36#[serde(rename_all = "snake_case")]
37pub enum GuardKind {
38    IfCondition,
39    RequireMacro,
40    AssertMacro,
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
44pub struct CallEvidence {
45    pub kind: CallKind,
46    pub callee: String,
47    pub span: AstSpan,
48    pub order: usize,
49    /// Account names extracted from the CpiContext struct for this CPI call.
50    /// Empty when the CPI accounts could not be resolved (raw invoke, unknown binding).
51    pub cpi_account_names: Vec<String>,
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
55#[serde(rename_all = "snake_case")]
56pub enum CallKind {
57    Deserialization,
58    Cpi,
59    Reload,
60    Other,
61}
62
63#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
64pub struct WriteEvidence {
65    pub target: String,
66    pub span: AstSpan,
67    pub order: usize,
68}
69
70pub fn collect_instruction_index(file: &syn::File) -> InstructionIndex {
71    let mut collector = InstructionCollector::default();
72    collector.visit_file(file);
73    InstructionIndex {
74        functions: collector.functions,
75    }
76}
77
78#[derive(Default)]
79struct InstructionCollector {
80    functions: Vec<InstructionFunction>,
81    module_stack: Vec<String>,
82    impl_stack: Vec<String>,
83}
84
85impl<'ast> Visit<'ast> for InstructionCollector {
86    fn visit_item_mod(&mut self, node: &'ast syn::ItemMod) {
87        self.module_stack.push(node.ident.to_string());
88
89        if let Some((_, items)) = &node.content {
90            for item in items {
91                self.visit_item(item);
92            }
93        }
94
95        self.module_stack.pop();
96    }
97
98    fn visit_item_impl(&mut self, node: &'ast syn::ItemImpl) {
99        self.impl_stack
100            .push(node.self_ty.to_token_stream().to_string());
101        visit::visit_item_impl(self, node);
102        self.impl_stack.pop();
103    }
104
105    fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
106        self.collect_function(node.sig.ident.to_string(), node.span(), &node.block);
107    }
108
109    fn visit_impl_item_fn(&mut self, node: &'ast syn::ImplItemFn) {
110        self.collect_function(node.sig.ident.to_string(), node.span(), &node.block);
111    }
112}
113
114impl InstructionCollector {
115    fn collect_function(&mut self, name: String, span: proc_macro2::Span, block: &syn::Block) {
116        let mut collector = FunctionBodyCollector::default();
117        collector.visit_block(block);
118
119        self.functions.push(InstructionFunction {
120            qualified_name: self.qualified_name(&name),
121            name,
122            span: span_of(span),
123            guards: collector.guards,
124            calls: collector.calls,
125            writes: collector.writes,
126        });
127    }
128
129    fn qualified_name(&self, name: &str) -> String {
130        let mut parts = self.module_stack.clone();
131        if let Some(impl_name) = self.impl_stack.last() {
132            parts.push(impl_name.clone());
133        }
134        parts.push(name.to_string());
135        parts.join("::")
136    }
137}
138
139#[derive(Default)]
140struct FunctionBodyCollector {
141    next_order: usize,
142    guards: Vec<GuardEvidence>,
143    calls: Vec<CallEvidence>,
144    writes: Vec<WriteEvidence>,
145    /// Maps local variable names to the account names found in the struct literal they were
146    /// bound to (e.g. `let accounts = Transfer { from: ctx.accounts.vault, ... }` →
147    /// `"accounts" → ["vault", ...]`). Used to resolve CpiContext args by name.
148    let_bindings: HashMap<String, Vec<String>>,
149}
150
151impl FunctionBodyCollector {
152    fn order(&mut self) -> usize {
153        self.next_order += 1;
154        self.next_order
155    }
156
157    fn push_guard(&mut self, kind: GuardKind, expression: syn::Expr) {
158        let features = ExprFeatures::from_expr(&expression);
159        self.push_guard_text(
160            kind,
161            expression.to_token_stream().to_string(),
162            expression.span(),
163            features,
164        );
165    }
166
167    fn push_guard_text(
168        &mut self,
169        kind: GuardKind,
170        expression: String,
171        span: proc_macro2::Span,
172        features: ExprFeatures,
173    ) {
174        let order = self.order();
175        self.guards.push(GuardEvidence {
176            kind,
177            expression,
178            span: span_of(span),
179            order,
180            references_owner: features.references_owner,
181            references_signer: features.references_signer,
182            references_key: features.references_key,
183        });
184    }
185
186    fn push_call(
187        &mut self,
188        callee: String,
189        span: proc_macro2::Span,
190        cpi_account_names: Vec<String>,
191    ) {
192        let order = self.order();
193        self.calls.push(CallEvidence {
194            kind: classify_call_kind(&callee),
195            callee,
196            span: span_of(span),
197            order,
198            cpi_account_names,
199        });
200    }
201
202    /// Walk `expr` and return the account names it refers to, following variable
203    /// bindings recorded in `self.let_bindings`. Handles:
204    /// - struct literals: `Transfer { from: ctx.accounts.X, ... }` → `["X", ...]`
205    /// - `CpiContext::new(prog, accounts_expr)` → recurse on accounts_expr
206    /// - variable paths: look up in `let_bindings`
207    /// - references `&expr`: strip and recurse
208    fn extract_account_names_from_expr(&self, expr: &syn::Expr) -> Vec<String> {
209        match expr {
210            syn::Expr::Struct(s) => s
211                .fields
212                .iter()
213                .filter_map(|f| {
214                    let val = normalize_tokens(&f.expr.to_token_stream().to_string());
215                    extract_account_name_from_str(&val)
216                })
217                .collect(),
218            syn::Expr::Call(call) => {
219                let func = normalize_tokens(&call.func.to_token_stream().to_string());
220                if func.contains("CpiContext::new") {
221                    if let Some(accounts_arg) = call.args.iter().nth(1) {
222                        return self.extract_account_names_from_expr(accounts_arg);
223                    }
224                }
225                vec![]
226            }
227            syn::Expr::Path(p) => {
228                let var = p
229                    .path
230                    .segments
231                    .last()
232                    .map(|s| s.ident.to_string())
233                    .unwrap_or_default();
234                self.let_bindings.get(&var).cloned().unwrap_or_default()
235            }
236            syn::Expr::Reference(r) => self.extract_account_names_from_expr(&r.expr),
237            _ => vec![],
238        }
239    }
240
241    fn push_write(&mut self, target: String, span: proc_macro2::Span) {
242        let order = self.order();
243        self.writes.push(WriteEvidence {
244            target,
245            span: span_of(span),
246            order,
247        });
248    }
249
250    fn record_guard_macro(
251        &mut self,
252        path: &syn::Path,
253        tokens: &proc_macro2::TokenStream,
254        span: proc_macro2::Span,
255    ) {
256        if let Some(kind) = classify_guard_macro(path) {
257            if let Some((expression, features)) = macro_guard_payload(path, tokens) {
258                self.push_guard_text(kind, expression, span, features);
259            }
260        }
261    }
262}
263
264impl<'ast> Visit<'ast> for FunctionBodyCollector {
265    fn visit_stmt(&mut self, node: &'ast syn::Stmt) {
266        if let syn::Stmt::Macro(stmt) = node {
267            self.record_guard_macro(&stmt.mac.path, &stmt.mac.tokens, stmt.mac.span());
268        }
269
270        visit::visit_stmt(self, node);
271    }
272
273    fn visit_expr_if(&mut self, node: &'ast syn::ExprIf) {
274        self.push_guard(GuardKind::IfCondition, (*node.cond).clone());
275        visit::visit_expr_if(self, node);
276    }
277
278    fn visit_expr_macro(&mut self, node: &'ast syn::ExprMacro) {
279        self.record_guard_macro(&node.mac.path, &node.mac.tokens, node.span());
280
281        visit::visit_expr_macro(self, node);
282    }
283
284    fn visit_local(&mut self, node: &'ast syn::Local) {
285        if let (Some(init), Some(var_name)) = (&node.init, get_simple_pat_ident(&node.pat)) {
286            let names = self.extract_account_names_from_expr(&init.expr);
287            if !names.is_empty() {
288                self.let_bindings.insert(var_name, names);
289            }
290        }
291        visit::visit_local(self, node);
292    }
293
294    fn visit_expr_call(&mut self, node: &'ast syn::ExprCall) {
295        let callee = normalize_tokens(&node.func.to_token_stream().to_string());
296        let cpi_account_names = if classify_call_kind(&callee) == CallKind::Cpi {
297            let mut found = vec![];
298            for arg in &node.args {
299                let names = self.extract_account_names_from_expr(arg);
300                if !names.is_empty() {
301                    found = names;
302                    break;
303                }
304            }
305            found
306        } else {
307            vec![]
308        };
309        self.push_call(callee, node.span(), cpi_account_names);
310        visit::visit_expr_call(self, node);
311    }
312
313    fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
314        let receiver = normalize_tokens(&node.receiver.to_token_stream().to_string());
315        let callee = format!("{receiver}.{}", node.method);
316        self.push_call(callee, node.span(), vec![]);
317        visit::visit_expr_method_call(self, node);
318    }
319
320    fn visit_expr_assign(&mut self, node: &'ast syn::ExprAssign) {
321        self.push_write(
322            normalize_tokens(&node.left.to_token_stream().to_string()),
323            node.span(),
324        );
325        visit::visit_expr_assign(self, node);
326    }
327
328    fn visit_expr_binary(&mut self, node: &'ast syn::ExprBinary) {
329        if is_assign_op(&node.op) {
330            self.push_write(
331                normalize_tokens(&node.left.to_token_stream().to_string()),
332                node.span(),
333            );
334        }
335
336        visit::visit_expr_binary(self, node);
337    }
338}
339
/// Flags recording which authority-related identifiers a guard expression mentions.
#[derive(Default)]
struct ExprFeatures {
    /// Set when an `owner` identifier appears.
    references_owner: bool,
    /// Set when `is_signer` or `signer` appears.
    references_signer: bool,
    /// Set when a `key` identifier appears.
    references_key: bool,
}
346
347impl ExprFeatures {
348    fn from_expr(expr: &syn::Expr) -> Self {
349        let mut collector = ExprFeatureCollector::default();
350        collector.visit_expr(expr);
351        collector.features
352    }
353
354    fn merge(self, other: Self) -> Self {
355        Self {
356            references_owner: self.references_owner || other.references_owner,
357            references_signer: self.references_signer || other.references_signer,
358            references_key: self.references_key || other.references_key,
359        }
360    }
361}
362
363#[derive(Default)]
364struct ExprFeatureCollector {
365    features: ExprFeatures,
366}
367
368impl<'ast> Visit<'ast> for ExprFeatureCollector {
369    fn visit_expr_field(&mut self, node: &'ast syn::ExprField) {
370        if let syn::Member::Named(member) = &node.member {
371            self.record_ident(member);
372        }
373        visit::visit_expr_field(self, node);
374    }
375
376    fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
377        self.record_ident(&node.method);
378        visit::visit_expr_method_call(self, node);
379    }
380
381    fn visit_path(&mut self, node: &'ast syn::Path) {
382        for segment in &node.segments {
383            self.record_ident(&segment.ident);
384        }
385        visit::visit_path(self, node);
386    }
387}
388
389impl ExprFeatureCollector {
390    fn record_ident(&mut self, ident: &syn::Ident) {
391        match ident.to_string().as_str() {
392            "owner" => self.features.references_owner = true,
393            "is_signer" | "signer" => self.features.references_signer = true,
394            "key" => self.features.references_key = true,
395            _ => {}
396        }
397    }
398}
399
400fn classify_guard_macro(path: &syn::Path) -> Option<GuardKind> {
401    let ident = path.segments.last()?.ident.to_string();
402    if ident.starts_with("require") {
403        Some(GuardKind::RequireMacro)
404    } else if ident.starts_with("assert") {
405        Some(GuardKind::AssertMacro)
406    } else {
407        None
408    }
409}
410
411fn parse_macro_guard_args(tokens: &proc_macro2::TokenStream) -> Option<Vec<syn::Expr>> {
412    let parser = syn::punctuated::Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated;
413    let args = parser.parse2(tokens.clone()).ok()?;
414    Some(args.into_iter().collect())
415}
416
417fn macro_guard_payload(
418    path: &syn::Path,
419    tokens: &proc_macro2::TokenStream,
420) -> Option<(String, ExprFeatures)> {
421    let args = parse_macro_guard_args(tokens)?;
422    if args.is_empty() {
423        return None;
424    }
425
426    let ident = path.segments.last()?.ident.to_string();
427    if (ident.ends_with("_eq") || ident == "assert_eq") && args.len() >= 2 {
428        let expression = format!(
429            "{} == {}",
430            args[0].to_token_stream(),
431            args[1].to_token_stream()
432        );
433        let features = ExprFeatures::from_expr(&args[0]).merge(ExprFeatures::from_expr(&args[1]));
434        return Some((expression, features));
435    }
436
437    if (ident.ends_with("_ne") || ident == "assert_ne") && args.len() >= 2 {
438        let expression = format!(
439            "{} != {}",
440            args[0].to_token_stream(),
441            args[1].to_token_stream()
442        );
443        let features = ExprFeatures::from_expr(&args[0]).merge(ExprFeatures::from_expr(&args[1]));
444        return Some((expression, features));
445    }
446
447    let first = args.into_iter().next()?;
448    let features = ExprFeatures::from_expr(&first);
449    Some((first.to_token_stream().to_string(), features))
450}
451
452fn classify_call_kind(callee: &str) -> CallKind {
453    let normalized = normalize_tokens(callee);
454    let lower = normalized.to_lowercase();
455
456    if normalized.ends_with(".reload") || normalized.ends_with("::reload") {
457        return CallKind::Reload;
458    }
459
460    if lower.contains("try_deserialize")
461        || normalized.ends_with("::try_from")
462        || normalized.ends_with("::from_account_info")
463        || normalized.ends_with(".load")
464        || normalized.ends_with(".load_mut")
465    {
466        return CallKind::Deserialization;
467    }
468
469    if normalized == "invoke"
470        || normalized == "invoke_signed"
471        || normalized.ends_with("::invoke")
472        || normalized.ends_with("::invoke_signed")
473        || normalized.contains("CpiContext::new")
474        || normalized.contains("CpiContext::new_with_signer")
475        || normalized.starts_with("token::")
476        || normalized.contains("anchor_spl::token::")
477    {
478        return CallKind::Cpi;
479    }
480
481    CallKind::Other
482}
483
484fn is_assign_op(op: &syn::BinOp) -> bool {
485    matches!(
486        op,
487        syn::BinOp::AddAssign(_)
488            | syn::BinOp::SubAssign(_)
489            | syn::BinOp::MulAssign(_)
490            | syn::BinOp::DivAssign(_)
491            | syn::BinOp::RemAssign(_)
492            | syn::BinOp::BitXorAssign(_)
493            | syn::BinOp::BitAndAssign(_)
494            | syn::BinOp::BitOrAssign(_)
495            | syn::BinOp::ShlAssign(_)
496            | syn::BinOp::ShrAssign(_)
497    )
498}
499
/// Removes every whitespace character, so the spaced output of
/// `TokenStream::to_string` becomes comparable text (`ctx . accounts` → `ctx.accounts`).
fn normalize_tokens(tokens: &str) -> String {
    tokens.chars().filter(|c| !c.is_whitespace()).collect()
}
503
504fn get_simple_pat_ident(pat: &syn::Pat) -> Option<String> {
505    if let syn::Pat::Ident(p) = pat {
506        Some(p.ident.to_string())
507    } else {
508        None
509    }
510}
511
/// Pulls an account name out of expression text such as
/// `ctx.accounts.vault.to_account_info()` (→ `"vault"`).
///
/// Only the first `.accounts.` occurrence is used. The name is the run of
/// alphanumeric/underscore characters that follows it. Returns `None` when there is
/// no `.accounts.` segment or nothing identifier-like follows it.
fn extract_account_name_from_str(s: &str) -> Option<String> {
    let (_, tail) = s.split_once(".accounts.")?;
    let name: String = tail
        .chars()
        .take_while(|&c| c == '_' || c.is_alphanumeric())
        .collect();
    Some(name).filter(|candidate| !candidate.is_empty())
}
527
#[cfg(test)]
mod tests {
    use super::*;

    fn parse_file(source: &str) -> syn::File {
        syn::parse_file(source).expect("source should parse")
    }

    #[test]
    fn collects_functions_from_modules_and_impls() {
        let file = parse_file(
            r#"
            mod instructions {
                pub fn process() {}
            }

            impl Processor {
                pub fn handle() {}
            }
            "#,
        );

        let index = collect_instruction_index(&file);
        assert_eq!(index.functions.len(), 2);
        assert_eq!(index.functions[0].qualified_name, "instructions::process");
        assert_eq!(index.functions[1].qualified_name, "Processor::handle");
    }

    #[test]
    fn models_guards_calls_and_writes_in_order() {
        let file = parse_file(
            r#"
            pub fn process(mut state: Account<'info, Vault>, authority: Signer<'info>) -> Result<()> {
                if state.owner != authority.key() {
                    return Err(ErrorCode::Unauthorized.into());
                }

                require!(authority.is_signer, ErrorCode::Unauthorized);

                let account = Vault::try_deserialize(&mut data)?;
                invoke_signed(&ix, &accounts, signer_seeds)?;
                state.reload()?;
                state.counter += account.amount;
                state.authority = authority.key();
                Ok(())
            }
            "#,
        );

        let index = collect_instruction_index(&file);
        assert_eq!(index.functions.len(), 1);

        let function = &index.functions[0];
        assert_eq!(function.qualified_name, "process");
        assert_eq!(function.guards.len(), 2);
        assert!(function.guards.iter().any(|guard| guard.references_owner));
        assert!(function.guards.iter().any(|guard| guard.references_signer));
        assert!(function.guards.iter().any(|guard| guard.references_key));

        assert!(function.calls.iter().any(|call| {
            call.kind == CallKind::Deserialization && call.callee.contains("Vault::try_deserialize")
        }));
        assert!(function
            .calls
            .iter()
            .any(|call| { call.kind == CallKind::Cpi && call.callee.contains("invoke_signed") }));
        assert!(function
            .calls
            .iter()
            .any(|call| { call.kind == CallKind::Reload && call.callee.contains("state.reload") }));

        assert!(function
            .writes
            .iter()
            .any(|write| write.target == "state.counter"));
        assert!(function
            .writes
            .iter()
            .any(|write| write.target == "state.authority"));

        let cpi_order = function
            .calls
            .iter()
            .find(|call| call.kind == CallKind::Cpi)
            .expect("cpi call should be recorded")
            .order;
        let reload_order = function
            .calls
            .iter()
            .find(|call| call.kind == CallKind::Reload)
            .expect("reload call should be recorded")
            .order;
        let write_order = function
            .writes
            .iter()
            .find(|write| write.target == "state.counter")
            .expect("counter write should be recorded")
            .order;

        assert!(cpi_order < reload_order);
        assert!(reload_order < write_order);
    }

    #[test]
    fn models_eq_style_guard_macros_with_both_operands() {
        let file = parse_file(
            r#"
            pub fn process(account: AccountInfo<'info>, authority: Signer<'info>) -> Result<()> {
                require_keys_eq!(account.owner, authority.key(), ErrorCode::Unauthorized);
                Ok(())
            }
            "#,
        );

        let index = collect_instruction_index(&file);
        let function = &index.functions[0];
        let guard = function
            .guards
            .iter()
            .find(|guard| guard.kind == GuardKind::RequireMacro)
            .expect("require macro guard should be recorded");

        assert!(guard.expression.contains("=="));
        assert!(guard.references_owner);
        assert!(guard.references_key);
    }

    /// The account names of a CPI are resolved through the `let` bindings for the
    /// accounts struct literal and the `CpiContext` built from it.
    #[test]
    fn resolves_cpi_account_names_through_let_bindings() {
        let file = parse_file(
            r#"
            pub fn withdraw(ctx: Context<Withdraw>, amount: u64) -> Result<()> {
                let accounts = Transfer {
                    from: ctx.accounts.vault.to_account_info(),
                    to: ctx.accounts.destination.to_account_info(),
                    authority: ctx.accounts.authority.to_account_info(),
                };
                let cpi_ctx = CpiContext::new(ctx.accounts.token_program.to_account_info(), accounts);
                token::transfer(cpi_ctx, amount)?;
                Ok(())
            }
            "#,
        );

        let index = collect_instruction_index(&file);
        let function = &index.functions[0];
        let transfer = function
            .calls
            .iter()
            .find(|call| call.callee == "token::transfer")
            .expect("token::transfer call should be recorded");

        assert_eq!(transfer.kind, CallKind::Cpi);
        assert_eq!(
            transfer.cpi_account_names,
            vec!["vault", "destination", "authority"]
        );
    }
}