1use crate::ast_index::{span_of, AstSpan};
2use quote::ToTokens;
3use serde::Serialize;
4use std::collections::HashMap;
5use syn::parse::Parser;
6use syn::spanned::Spanned;
7use syn::visit::{self, Visit};
8
9#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
10pub struct InstructionIndex {
11 pub functions: Vec<InstructionFunction>,
12}
13
14#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
15pub struct InstructionFunction {
16 pub name: String,
17 pub qualified_name: String,
18 pub span: AstSpan,
19 pub guards: Vec<GuardEvidence>,
20 pub calls: Vec<CallEvidence>,
21 pub writes: Vec<WriteEvidence>,
22}
23
24#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
25pub struct GuardEvidence {
26 pub kind: GuardKind,
27 pub expression: String,
28 pub span: AstSpan,
29 pub order: usize,
30 pub references_owner: bool,
31 pub references_signer: bool,
32 pub references_key: bool,
33}
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
36#[serde(rename_all = "snake_case")]
37pub enum GuardKind {
38 IfCondition,
39 RequireMacro,
40 AssertMacro,
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
44pub struct CallEvidence {
45 pub kind: CallKind,
46 pub callee: String,
47 pub span: AstSpan,
48 pub order: usize,
49 pub cpi_account_names: Vec<String>,
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
55#[serde(rename_all = "snake_case")]
56pub enum CallKind {
57 Deserialization,
58 Cpi,
59 Reload,
60 Other,
61}
62
63#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
64pub struct WriteEvidence {
65 pub target: String,
66 pub span: AstSpan,
67 pub order: usize,
68}
69
70pub fn collect_instruction_index(file: &syn::File) -> InstructionIndex {
71 let mut collector = InstructionCollector::default();
72 collector.visit_file(file);
73 InstructionIndex {
74 functions: collector.functions,
75 }
76}
77
78#[derive(Default)]
79struct InstructionCollector {
80 functions: Vec<InstructionFunction>,
81 module_stack: Vec<String>,
82 impl_stack: Vec<String>,
83}
84
85impl<'ast> Visit<'ast> for InstructionCollector {
86 fn visit_item_mod(&mut self, node: &'ast syn::ItemMod) {
87 self.module_stack.push(node.ident.to_string());
88
89 if let Some((_, items)) = &node.content {
90 for item in items {
91 self.visit_item(item);
92 }
93 }
94
95 self.module_stack.pop();
96 }
97
98 fn visit_item_impl(&mut self, node: &'ast syn::ItemImpl) {
99 self.impl_stack
100 .push(node.self_ty.to_token_stream().to_string());
101 visit::visit_item_impl(self, node);
102 self.impl_stack.pop();
103 }
104
105 fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
106 self.collect_function(node.sig.ident.to_string(), node.span(), &node.block);
107 }
108
109 fn visit_impl_item_fn(&mut self, node: &'ast syn::ImplItemFn) {
110 self.collect_function(node.sig.ident.to_string(), node.span(), &node.block);
111 }
112}
113
114impl InstructionCollector {
115 fn collect_function(&mut self, name: String, span: proc_macro2::Span, block: &syn::Block) {
116 let mut collector = FunctionBodyCollector::default();
117 collector.visit_block(block);
118
119 self.functions.push(InstructionFunction {
120 qualified_name: self.qualified_name(&name),
121 name,
122 span: span_of(span),
123 guards: collector.guards,
124 calls: collector.calls,
125 writes: collector.writes,
126 });
127 }
128
129 fn qualified_name(&self, name: &str) -> String {
130 let mut parts = self.module_stack.clone();
131 if let Some(impl_name) = self.impl_stack.last() {
132 parts.push(impl_name.clone());
133 }
134 parts.push(name.to_string());
135 parts.join("::")
136 }
137}
138
139#[derive(Default)]
140struct FunctionBodyCollector {
141 next_order: usize,
142 guards: Vec<GuardEvidence>,
143 calls: Vec<CallEvidence>,
144 writes: Vec<WriteEvidence>,
145 let_bindings: HashMap<String, Vec<String>>,
149}
150
151impl FunctionBodyCollector {
152 fn order(&mut self) -> usize {
153 self.next_order += 1;
154 self.next_order
155 }
156
157 fn push_guard(&mut self, kind: GuardKind, expression: syn::Expr) {
158 let features = ExprFeatures::from_expr(&expression);
159 self.push_guard_text(
160 kind,
161 expression.to_token_stream().to_string(),
162 expression.span(),
163 features,
164 );
165 }
166
167 fn push_guard_text(
168 &mut self,
169 kind: GuardKind,
170 expression: String,
171 span: proc_macro2::Span,
172 features: ExprFeatures,
173 ) {
174 let order = self.order();
175 self.guards.push(GuardEvidence {
176 kind,
177 expression,
178 span: span_of(span),
179 order,
180 references_owner: features.references_owner,
181 references_signer: features.references_signer,
182 references_key: features.references_key,
183 });
184 }
185
186 fn push_call(
187 &mut self,
188 callee: String,
189 span: proc_macro2::Span,
190 cpi_account_names: Vec<String>,
191 ) {
192 let order = self.order();
193 self.calls.push(CallEvidence {
194 kind: classify_call_kind(&callee),
195 callee,
196 span: span_of(span),
197 order,
198 cpi_account_names,
199 });
200 }
201
202 fn extract_account_names_from_expr(&self, expr: &syn::Expr) -> Vec<String> {
209 match expr {
210 syn::Expr::Struct(s) => s
211 .fields
212 .iter()
213 .filter_map(|f| {
214 let val = normalize_tokens(&f.expr.to_token_stream().to_string());
215 extract_account_name_from_str(&val)
216 })
217 .collect(),
218 syn::Expr::Call(call) => {
219 let func = normalize_tokens(&call.func.to_token_stream().to_string());
220 if func.contains("CpiContext::new") {
221 if let Some(accounts_arg) = call.args.iter().nth(1) {
222 return self.extract_account_names_from_expr(accounts_arg);
223 }
224 }
225 vec![]
226 }
227 syn::Expr::Path(p) => {
228 let var = p
229 .path
230 .segments
231 .last()
232 .map(|s| s.ident.to_string())
233 .unwrap_or_default();
234 self.let_bindings.get(&var).cloned().unwrap_or_default()
235 }
236 syn::Expr::Reference(r) => self.extract_account_names_from_expr(&r.expr),
237 _ => vec![],
238 }
239 }
240
241 fn push_write(&mut self, target: String, span: proc_macro2::Span) {
242 let order = self.order();
243 self.writes.push(WriteEvidence {
244 target,
245 span: span_of(span),
246 order,
247 });
248 }
249
250 fn record_guard_macro(
251 &mut self,
252 path: &syn::Path,
253 tokens: &proc_macro2::TokenStream,
254 span: proc_macro2::Span,
255 ) {
256 if let Some(kind) = classify_guard_macro(path) {
257 if let Some((expression, features)) = macro_guard_payload(path, tokens) {
258 self.push_guard_text(kind, expression, span, features);
259 }
260 }
261 }
262}
263
264impl<'ast> Visit<'ast> for FunctionBodyCollector {
265 fn visit_stmt(&mut self, node: &'ast syn::Stmt) {
266 if let syn::Stmt::Macro(stmt) = node {
267 self.record_guard_macro(&stmt.mac.path, &stmt.mac.tokens, stmt.mac.span());
268 }
269
270 visit::visit_stmt(self, node);
271 }
272
273 fn visit_expr_if(&mut self, node: &'ast syn::ExprIf) {
274 self.push_guard(GuardKind::IfCondition, (*node.cond).clone());
275 visit::visit_expr_if(self, node);
276 }
277
278 fn visit_expr_macro(&mut self, node: &'ast syn::ExprMacro) {
279 self.record_guard_macro(&node.mac.path, &node.mac.tokens, node.span());
280
281 visit::visit_expr_macro(self, node);
282 }
283
284 fn visit_local(&mut self, node: &'ast syn::Local) {
285 if let (Some(init), Some(var_name)) = (&node.init, get_simple_pat_ident(&node.pat)) {
286 let names = self.extract_account_names_from_expr(&init.expr);
287 if !names.is_empty() {
288 self.let_bindings.insert(var_name, names);
289 }
290 }
291 visit::visit_local(self, node);
292 }
293
294 fn visit_expr_call(&mut self, node: &'ast syn::ExprCall) {
295 let callee = normalize_tokens(&node.func.to_token_stream().to_string());
296 let cpi_account_names = if classify_call_kind(&callee) == CallKind::Cpi {
297 let mut found = vec![];
298 for arg in &node.args {
299 let names = self.extract_account_names_from_expr(arg);
300 if !names.is_empty() {
301 found = names;
302 break;
303 }
304 }
305 found
306 } else {
307 vec![]
308 };
309 self.push_call(callee, node.span(), cpi_account_names);
310 visit::visit_expr_call(self, node);
311 }
312
313 fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
314 let receiver = normalize_tokens(&node.receiver.to_token_stream().to_string());
315 let callee = format!("{receiver}.{}", node.method);
316 self.push_call(callee, node.span(), vec![]);
317 visit::visit_expr_method_call(self, node);
318 }
319
320 fn visit_expr_assign(&mut self, node: &'ast syn::ExprAssign) {
321 self.push_write(
322 normalize_tokens(&node.left.to_token_stream().to_string()),
323 node.span(),
324 );
325 visit::visit_expr_assign(self, node);
326 }
327
328 fn visit_expr_binary(&mut self, node: &'ast syn::ExprBinary) {
329 if is_assign_op(&node.op) {
330 self.push_write(
331 normalize_tokens(&node.left.to_token_stream().to_string()),
332 node.span(),
333 );
334 }
335
336 visit::visit_expr_binary(self, node);
337 }
338}
339
/// Which identifiers of interest an expression mentions.
#[derive(Default)]
struct ExprFeatures {
    /// An identifier named `owner` was seen.
    references_owner: bool,
    /// An identifier named `is_signer` or `signer` was seen.
    references_signer: bool,
    /// An identifier named `key` was seen.
    references_key: bool,
}
346
347impl ExprFeatures {
348 fn from_expr(expr: &syn::Expr) -> Self {
349 let mut collector = ExprFeatureCollector::default();
350 collector.visit_expr(expr);
351 collector.features
352 }
353
354 fn merge(self, other: Self) -> Self {
355 Self {
356 references_owner: self.references_owner || other.references_owner,
357 references_signer: self.references_signer || other.references_signer,
358 references_key: self.references_key || other.references_key,
359 }
360 }
361}
362
363#[derive(Default)]
364struct ExprFeatureCollector {
365 features: ExprFeatures,
366}
367
368impl<'ast> Visit<'ast> for ExprFeatureCollector {
369 fn visit_expr_field(&mut self, node: &'ast syn::ExprField) {
370 if let syn::Member::Named(member) = &node.member {
371 self.record_ident(member);
372 }
373 visit::visit_expr_field(self, node);
374 }
375
376 fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
377 self.record_ident(&node.method);
378 visit::visit_expr_method_call(self, node);
379 }
380
381 fn visit_path(&mut self, node: &'ast syn::Path) {
382 for segment in &node.segments {
383 self.record_ident(&segment.ident);
384 }
385 visit::visit_path(self, node);
386 }
387}
388
389impl ExprFeatureCollector {
390 fn record_ident(&mut self, ident: &syn::Ident) {
391 match ident.to_string().as_str() {
392 "owner" => self.features.references_owner = true,
393 "is_signer" | "signer" => self.features.references_signer = true,
394 "key" => self.features.references_key = true,
395 _ => {}
396 }
397 }
398}
399
400fn classify_guard_macro(path: &syn::Path) -> Option<GuardKind> {
401 let ident = path.segments.last()?.ident.to_string();
402 if ident.starts_with("require") {
403 Some(GuardKind::RequireMacro)
404 } else if ident.starts_with("assert") {
405 Some(GuardKind::AssertMacro)
406 } else {
407 None
408 }
409}
410
411fn parse_macro_guard_args(tokens: &proc_macro2::TokenStream) -> Option<Vec<syn::Expr>> {
412 let parser = syn::punctuated::Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated;
413 let args = parser.parse2(tokens.clone()).ok()?;
414 Some(args.into_iter().collect())
415}
416
417fn macro_guard_payload(
418 path: &syn::Path,
419 tokens: &proc_macro2::TokenStream,
420) -> Option<(String, ExprFeatures)> {
421 let args = parse_macro_guard_args(tokens)?;
422 if args.is_empty() {
423 return None;
424 }
425
426 let ident = path.segments.last()?.ident.to_string();
427 if (ident.ends_with("_eq") || ident == "assert_eq") && args.len() >= 2 {
428 let expression = format!(
429 "{} == {}",
430 args[0].to_token_stream(),
431 args[1].to_token_stream()
432 );
433 let features = ExprFeatures::from_expr(&args[0]).merge(ExprFeatures::from_expr(&args[1]));
434 return Some((expression, features));
435 }
436
437 if (ident.ends_with("_ne") || ident == "assert_ne") && args.len() >= 2 {
438 let expression = format!(
439 "{} != {}",
440 args[0].to_token_stream(),
441 args[1].to_token_stream()
442 );
443 let features = ExprFeatures::from_expr(&args[0]).merge(ExprFeatures::from_expr(&args[1]));
444 return Some((expression, features));
445 }
446
447 let first = args.into_iter().next()?;
448 let features = ExprFeatures::from_expr(&first);
449 Some((first.to_token_stream().to_string(), features))
450}
451
452fn classify_call_kind(callee: &str) -> CallKind {
453 let normalized = normalize_tokens(callee);
454 let lower = normalized.to_lowercase();
455
456 if normalized.ends_with(".reload") || normalized.ends_with("::reload") {
457 return CallKind::Reload;
458 }
459
460 if lower.contains("try_deserialize")
461 || normalized.ends_with("::try_from")
462 || normalized.ends_with("::from_account_info")
463 || normalized.ends_with(".load")
464 || normalized.ends_with(".load_mut")
465 {
466 return CallKind::Deserialization;
467 }
468
469 if normalized == "invoke"
470 || normalized == "invoke_signed"
471 || normalized.ends_with("::invoke")
472 || normalized.ends_with("::invoke_signed")
473 || normalized.contains("CpiContext::new")
474 || normalized.contains("CpiContext::new_with_signer")
475 || normalized.starts_with("token::")
476 || normalized.contains("anchor_spl::token::")
477 {
478 return CallKind::Cpi;
479 }
480
481 CallKind::Other
482}
483
484fn is_assign_op(op: &syn::BinOp) -> bool {
485 matches!(
486 op,
487 syn::BinOp::AddAssign(_)
488 | syn::BinOp::SubAssign(_)
489 | syn::BinOp::MulAssign(_)
490 | syn::BinOp::DivAssign(_)
491 | syn::BinOp::RemAssign(_)
492 | syn::BinOp::BitXorAssign(_)
493 | syn::BinOp::BitAndAssign(_)
494 | syn::BinOp::BitOrAssign(_)
495 | syn::BinOp::ShlAssign(_)
496 | syn::BinOp::ShrAssign(_)
497 )
498}
499
/// Returns `tokens` with every whitespace character removed.
fn normalize_tokens(tokens: &str) -> String {
    tokens.chars().filter(|ch| !ch.is_whitespace()).collect()
}
503
504fn get_simple_pat_ident(pat: &syn::Pat) -> Option<String> {
505 if let syn::Pat::Ident(p) = pat {
506 Some(p.ident.to_string())
507 } else {
508 None
509 }
510}
511
/// Returns the identifier that directly follows the first `.accounts.` in `s`.
///
/// The identifier is the longest run of alphanumeric characters and
/// underscores; `None` is returned when `.accounts.` is absent or is not
/// followed by such a character.
fn extract_account_name_from_str(s: &str) -> Option<String> {
    let (_, tail) = s.split_once(".accounts.")?;

    let end = tail
        .find(|ch: char| !(ch.is_alphanumeric() || ch == '_'))
        .unwrap_or(tail.len());
    let name = &tail[..end];

    if name.is_empty() {
        None
    } else {
        Some(name.to_string())
    }
}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as a Rust file, panicking when it is not valid syntax.
    fn parse_file(source: &str) -> syn::File {
        syn::parse_file(source).expect("source should parse")
    }

    #[test]
    fn collects_functions_from_modules_and_impls() {
        let file = parse_file(
            r#"
            mod instructions {
                pub fn process() {}
            }

            impl Processor {
                pub fn handle() {}
            }
            "#,
        );

        let index = collect_instruction_index(&file);
        let qualified: Vec<&str> = index
            .functions
            .iter()
            .map(|function| function.qualified_name.as_str())
            .collect();

        // Exactly two functions, in source order.
        assert_eq!(qualified, ["instructions::process", "Processor::handle"]);
    }

    #[test]
    fn models_guards_calls_and_writes_in_order() {
        let file = parse_file(
            r#"
            pub fn process(mut state: Account<'info, Vault>, authority: Signer<'info>) -> Result<()> {
                if state.owner != authority.key() {
                    return Err(ErrorCode::Unauthorized.into());
                }

                require!(authority.is_signer, ErrorCode::Unauthorized);

                let account = Vault::try_deserialize(&mut data)?;
                invoke_signed(&ix, &accounts, signer_seeds)?;
                state.reload()?;
                state.counter += account.amount;
                state.authority = authority.key();
                Ok(())
            }
            "#,
        );

        let index = collect_instruction_index(&file);
        assert_eq!(index.functions.len(), 1);

        let function = &index.functions[0];
        assert_eq!(function.qualified_name, "process");

        // Guards: the `if` condition and the `require!` macro.
        assert_eq!(function.guards.len(), 2);
        assert!(function.guards.iter().any(|guard| guard.references_owner));
        assert!(function.guards.iter().any(|guard| guard.references_signer));
        assert!(function.guards.iter().any(|guard| guard.references_key));

        // Calls, by kind and callee fragment.
        let has_call = |kind: CallKind, needle: &str| {
            function
                .calls
                .iter()
                .any(|call| call.kind == kind && call.callee.contains(needle))
        };
        assert!(has_call(CallKind::Deserialization, "Vault::try_deserialize"));
        assert!(has_call(CallKind::Cpi, "invoke_signed"));
        assert!(has_call(CallKind::Reload, "state.reload"));

        // Writes, by exact target.
        let has_write = |target: &str| function.writes.iter().any(|write| write.target == target);
        assert!(has_write("state.counter"));
        assert!(has_write("state.authority"));

        // Relative ordering: CPI, then reload, then the counter write.
        let call_order = |kind: CallKind, missing: &str| {
            function
                .calls
                .iter()
                .find(|call| call.kind == kind)
                .expect(missing)
                .order
        };
        let cpi_order = call_order(CallKind::Cpi, "cpi call should be recorded");
        let reload_order = call_order(CallKind::Reload, "reload call should be recorded");
        let write_order = function
            .writes
            .iter()
            .find(|write| write.target == "state.counter")
            .expect("counter write should be recorded")
            .order;

        assert!(cpi_order < reload_order);
        assert!(reload_order < write_order);
    }

    #[test]
    fn models_eq_style_guard_macros_with_both_operands() {
        let file = parse_file(
            r#"
            pub fn process(account: AccountInfo<'info>, authority: Signer<'info>) -> Result<()> {
                require_keys_eq!(account.owner, authority.key(), ErrorCode::Unauthorized);
                Ok(())
            }
            "#,
        );

        let index = collect_instruction_index(&file);
        let function = &index.functions[0];
        let require_guard = function
            .guards
            .iter()
            .find(|guard| guard.kind == GuardKind::RequireMacro)
            .expect("require macro guard should be recorded");

        // Both operands must contribute: `owner` from the left, `key` from the right.
        assert!(require_guard.expression.contains("=="));
        assert!(require_guard.references_owner);
        assert!(require_guard.references_key);
    }
}