1use crate::ast_index::{span_of, AstSpan};
2use quote::ToTokens;
3use serde::Serialize;
4use std::collections::HashMap;
5use syn::parse::Parser;
6use syn::spanned::Spanned;
7use syn::visit::{self, Visit};
8
9#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
10pub struct InstructionIndex {
11 pub functions: Vec<InstructionFunction>,
12}
13
14#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
15pub struct InstructionFunction {
16 pub name: String,
17 pub qualified_name: String,
18 pub span: AstSpan,
19 pub guards: Vec<GuardEvidence>,
20 pub calls: Vec<CallEvidence>,
21 pub writes: Vec<WriteEvidence>,
22}
23
24#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
25pub struct GuardEvidence {
26 pub kind: GuardKind,
27 pub expression: String,
28 pub span: AstSpan,
29 pub order: usize,
30 pub references_owner: bool,
31 pub references_signer: bool,
32 pub references_key: bool,
33}
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
36#[serde(rename_all = "snake_case")]
37pub enum GuardKind {
38 IfCondition,
39 RequireMacro,
40 AssertMacro,
41}
42
43#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
44pub struct CallEvidence {
45 pub kind: CallKind,
46 pub callee: String,
47 pub span: AstSpan,
48 pub order: usize,
49 pub cpi_account_names: Vec<String>,
52}
53
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
55#[serde(rename_all = "snake_case")]
56pub enum CallKind {
57 Deserialization,
58 Cpi,
59 Reload,
60 Other,
61}
62
63#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
64pub struct WriteEvidence {
65 pub target: String,
66 pub span: AstSpan,
67 pub order: usize,
68}
69
70pub fn collect_instruction_index(file: &syn::File) -> InstructionIndex {
71 let mut collector = InstructionCollector::default();
72 collector.visit_file(file);
73 InstructionIndex {
74 functions: collector.functions,
75 }
76}
77
78#[derive(Default)]
79struct InstructionCollector {
80 functions: Vec<InstructionFunction>,
81 module_stack: Vec<String>,
82 impl_stack: Vec<String>,
83}
84
85impl<'ast> Visit<'ast> for InstructionCollector {
86 fn visit_item_mod(&mut self, node: &'ast syn::ItemMod) {
87 self.module_stack.push(node.ident.to_string());
88
89 if let Some((_, items)) = &node.content {
90 for item in items {
91 self.visit_item(item);
92 }
93 }
94
95 self.module_stack.pop();
96 }
97
98 fn visit_item_impl(&mut self, node: &'ast syn::ItemImpl) {
99 self.impl_stack.push(node.self_ty.to_token_stream().to_string());
100 visit::visit_item_impl(self, node);
101 self.impl_stack.pop();
102 }
103
104 fn visit_item_fn(&mut self, node: &'ast syn::ItemFn) {
105 self.collect_function(node.sig.ident.to_string(), node.span(), &node.block);
106 }
107
108 fn visit_impl_item_fn(&mut self, node: &'ast syn::ImplItemFn) {
109 self.collect_function(node.sig.ident.to_string(), node.span(), &node.block);
110 }
111}
112
113impl InstructionCollector {
114 fn collect_function(&mut self, name: String, span: proc_macro2::Span, block: &syn::Block) {
115 let mut collector = FunctionBodyCollector::default();
116 collector.visit_block(block);
117
118 self.functions.push(InstructionFunction {
119 qualified_name: self.qualified_name(&name),
120 name,
121 span: span_of(span),
122 guards: collector.guards,
123 calls: collector.calls,
124 writes: collector.writes,
125 });
126 }
127
128 fn qualified_name(&self, name: &str) -> String {
129 let mut parts = self.module_stack.clone();
130 if let Some(impl_name) = self.impl_stack.last() {
131 parts.push(impl_name.clone());
132 }
133 parts.push(name.to_string());
134 parts.join("::")
135 }
136}
137
138#[derive(Default)]
139struct FunctionBodyCollector {
140 next_order: usize,
141 guards: Vec<GuardEvidence>,
142 calls: Vec<CallEvidence>,
143 writes: Vec<WriteEvidence>,
144 let_bindings: HashMap<String, Vec<String>>,
148}
149
150impl FunctionBodyCollector {
151 fn order(&mut self) -> usize {
152 self.next_order += 1;
153 self.next_order
154 }
155
156 fn push_guard(&mut self, kind: GuardKind, expression: syn::Expr) {
157 let features = ExprFeatures::from_expr(&expression);
158 self.push_guard_text(kind, expression.to_token_stream().to_string(), expression.span(), features);
159 }
160
161 fn push_guard_text(
162 &mut self,
163 kind: GuardKind,
164 expression: String,
165 span: proc_macro2::Span,
166 features: ExprFeatures,
167 ) {
168 let order = self.order();
169 self.guards.push(GuardEvidence {
170 kind,
171 expression,
172 span: span_of(span),
173 order,
174 references_owner: features.references_owner,
175 references_signer: features.references_signer,
176 references_key: features.references_key,
177 });
178 }
179
180 fn push_call(&mut self, callee: String, span: proc_macro2::Span, cpi_account_names: Vec<String>) {
181 let order = self.order();
182 self.calls.push(CallEvidence {
183 kind: classify_call_kind(&callee),
184 callee,
185 span: span_of(span),
186 order,
187 cpi_account_names,
188 });
189 }
190
191 fn extract_account_names_from_expr(&self, expr: &syn::Expr) -> Vec<String> {
198 match expr {
199 syn::Expr::Struct(s) => s
200 .fields
201 .iter()
202 .filter_map(|f| {
203 let val = normalize_tokens(&f.expr.to_token_stream().to_string());
204 extract_account_name_from_str(&val)
205 })
206 .collect(),
207 syn::Expr::Call(call) => {
208 let func = normalize_tokens(&call.func.to_token_stream().to_string());
209 if func.contains("CpiContext::new") {
210 if let Some(accounts_arg) = call.args.iter().nth(1) {
211 return self.extract_account_names_from_expr(accounts_arg);
212 }
213 }
214 vec![]
215 }
216 syn::Expr::Path(p) => {
217 let var = p
218 .path
219 .segments
220 .last()
221 .map(|s| s.ident.to_string())
222 .unwrap_or_default();
223 self.let_bindings.get(&var).cloned().unwrap_or_default()
224 }
225 syn::Expr::Reference(r) => self.extract_account_names_from_expr(&r.expr),
226 _ => vec![],
227 }
228 }
229
230 fn push_write(&mut self, target: String, span: proc_macro2::Span) {
231 let order = self.order();
232 self.writes.push(WriteEvidence {
233 target,
234 span: span_of(span),
235 order,
236 });
237 }
238
239 fn record_guard_macro(
240 &mut self,
241 path: &syn::Path,
242 tokens: &proc_macro2::TokenStream,
243 span: proc_macro2::Span,
244 ) {
245 if let Some(kind) = classify_guard_macro(path) {
246 if let Some((expression, features)) = macro_guard_payload(path, tokens) {
247 self.push_guard_text(kind, expression, span, features);
248 }
249 }
250 }
251}
252
253impl<'ast> Visit<'ast> for FunctionBodyCollector {
254 fn visit_stmt(&mut self, node: &'ast syn::Stmt) {
255 if let syn::Stmt::Macro(stmt) = node {
256 self.record_guard_macro(&stmt.mac.path, &stmt.mac.tokens, stmt.mac.span());
257 }
258
259 visit::visit_stmt(self, node);
260 }
261
262 fn visit_expr_if(&mut self, node: &'ast syn::ExprIf) {
263 self.push_guard(GuardKind::IfCondition, (*node.cond).clone());
264 visit::visit_expr_if(self, node);
265 }
266
267 fn visit_expr_macro(&mut self, node: &'ast syn::ExprMacro) {
268 self.record_guard_macro(&node.mac.path, &node.mac.tokens, node.span());
269
270 visit::visit_expr_macro(self, node);
271 }
272
273 fn visit_local(&mut self, node: &'ast syn::Local) {
274 if let (Some(init), Some(var_name)) = (&node.init, get_simple_pat_ident(&node.pat)) {
275 let names = self.extract_account_names_from_expr(&init.expr);
276 if !names.is_empty() {
277 self.let_bindings.insert(var_name, names);
278 }
279 }
280 visit::visit_local(self, node);
281 }
282
283 fn visit_expr_call(&mut self, node: &'ast syn::ExprCall) {
284 let callee = normalize_tokens(&node.func.to_token_stream().to_string());
285 let cpi_account_names = if classify_call_kind(&callee) == CallKind::Cpi {
286 let mut found = vec![];
287 for arg in &node.args {
288 let names = self.extract_account_names_from_expr(arg);
289 if !names.is_empty() {
290 found = names;
291 break;
292 }
293 }
294 found
295 } else {
296 vec![]
297 };
298 self.push_call(callee, node.span(), cpi_account_names);
299 visit::visit_expr_call(self, node);
300 }
301
302 fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
303 let receiver = normalize_tokens(&node.receiver.to_token_stream().to_string());
304 let callee = format!("{receiver}.{}", node.method);
305 self.push_call(callee, node.span(), vec![]);
306 visit::visit_expr_method_call(self, node);
307 }
308
309 fn visit_expr_assign(&mut self, node: &'ast syn::ExprAssign) {
310 self.push_write(normalize_tokens(&node.left.to_token_stream().to_string()), node.span());
311 visit::visit_expr_assign(self, node);
312 }
313
314 fn visit_expr_binary(&mut self, node: &'ast syn::ExprBinary) {
315 if is_assign_op(&node.op) {
316 self.push_write(normalize_tokens(&node.left.to_token_stream().to_string()), node.span());
317 }
318
319 visit::visit_expr_binary(self, node);
320 }
321}
322
/// Flags recording which identifiers of interest occur in an expression.
#[derive(Default)]
struct ExprFeatures {
    /// The identifier `owner` was seen.
    references_owner: bool,
    /// The identifier `is_signer` or `signer` was seen.
    references_signer: bool,
    /// The identifier `key` was seen.
    references_key: bool,
}
329
330impl ExprFeatures {
331 fn from_expr(expr: &syn::Expr) -> Self {
332 let mut collector = ExprFeatureCollector::default();
333 collector.visit_expr(expr);
334 collector.features
335 }
336
337 fn merge(self, other: Self) -> Self {
338 Self {
339 references_owner: self.references_owner || other.references_owner,
340 references_signer: self.references_signer || other.references_signer,
341 references_key: self.references_key || other.references_key,
342 }
343 }
344}
345
346#[derive(Default)]
347struct ExprFeatureCollector {
348 features: ExprFeatures,
349}
350
351impl<'ast> Visit<'ast> for ExprFeatureCollector {
352 fn visit_expr_field(&mut self, node: &'ast syn::ExprField) {
353 if let syn::Member::Named(member) = &node.member {
354 self.record_ident(member);
355 }
356 visit::visit_expr_field(self, node);
357 }
358
359 fn visit_expr_method_call(&mut self, node: &'ast syn::ExprMethodCall) {
360 self.record_ident(&node.method);
361 visit::visit_expr_method_call(self, node);
362 }
363
364 fn visit_path(&mut self, node: &'ast syn::Path) {
365 for segment in &node.segments {
366 self.record_ident(&segment.ident);
367 }
368 visit::visit_path(self, node);
369 }
370}
371
372impl ExprFeatureCollector {
373 fn record_ident(&mut self, ident: &syn::Ident) {
374 match ident.to_string().as_str() {
375 "owner" => self.features.references_owner = true,
376 "is_signer" | "signer" => self.features.references_signer = true,
377 "key" => self.features.references_key = true,
378 _ => {}
379 }
380 }
381}
382
383fn classify_guard_macro(path: &syn::Path) -> Option<GuardKind> {
384 let ident = path.segments.last()?.ident.to_string();
385 if ident.starts_with("require") {
386 Some(GuardKind::RequireMacro)
387 } else if ident.starts_with("assert") {
388 Some(GuardKind::AssertMacro)
389 } else {
390 None
391 }
392}
393
394fn parse_macro_guard_args(tokens: &proc_macro2::TokenStream) -> Option<Vec<syn::Expr>> {
395 let parser = syn::punctuated::Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated;
396 let args = parser.parse2(tokens.clone()).ok()?;
397 Some(args.into_iter().collect())
398}
399
400fn macro_guard_payload(path: &syn::Path, tokens: &proc_macro2::TokenStream) -> Option<(String, ExprFeatures)> {
401 let args = parse_macro_guard_args(tokens)?;
402 if args.is_empty() {
403 return None;
404 }
405
406 let ident = path.segments.last()?.ident.to_string();
407 if (ident.ends_with("_eq") || ident == "assert_eq") && args.len() >= 2 {
408 let expression = format!(
409 "{} == {}",
410 args[0].to_token_stream(),
411 args[1].to_token_stream()
412 );
413 let features = ExprFeatures::from_expr(&args[0]).merge(ExprFeatures::from_expr(&args[1]));
414 return Some((expression, features));
415 }
416
417 if (ident.ends_with("_ne") || ident == "assert_ne") && args.len() >= 2 {
418 let expression = format!(
419 "{} != {}",
420 args[0].to_token_stream(),
421 args[1].to_token_stream()
422 );
423 let features = ExprFeatures::from_expr(&args[0]).merge(ExprFeatures::from_expr(&args[1]));
424 return Some((expression, features));
425 }
426
427 let first = args.into_iter().next()?;
428 let features = ExprFeatures::from_expr(&first);
429 Some((first.to_token_stream().to_string(), features))
430}
431
432fn classify_call_kind(callee: &str) -> CallKind {
433 let normalized = normalize_tokens(callee);
434 let lower = normalized.to_lowercase();
435
436 if normalized.ends_with(".reload") || normalized.ends_with("::reload") {
437 return CallKind::Reload;
438 }
439
440 if lower.contains("try_deserialize")
441 || normalized.ends_with("::try_from")
442 || normalized.ends_with("::from_account_info")
443 || normalized.ends_with(".load")
444 || normalized.ends_with(".load_mut")
445 {
446 return CallKind::Deserialization;
447 }
448
449 if normalized == "invoke"
450 || normalized == "invoke_signed"
451 || normalized.ends_with("::invoke")
452 || normalized.ends_with("::invoke_signed")
453 || normalized.contains("CpiContext::new")
454 || normalized.contains("CpiContext::new_with_signer")
455 || normalized.starts_with("token::")
456 || normalized.contains("anchor_spl::token::")
457 {
458 return CallKind::Cpi;
459 }
460
461 CallKind::Other
462}
463
464fn is_assign_op(op: &syn::BinOp) -> bool {
465 matches!(
466 op,
467 syn::BinOp::AddAssign(_)
468 | syn::BinOp::SubAssign(_)
469 | syn::BinOp::MulAssign(_)
470 | syn::BinOp::DivAssign(_)
471 | syn::BinOp::RemAssign(_)
472 | syn::BinOp::BitXorAssign(_)
473 | syn::BinOp::BitAndAssign(_)
474 | syn::BinOp::BitOrAssign(_)
475 | syn::BinOp::ShlAssign(_)
476 | syn::BinOp::ShrAssign(_)
477 )
478}
479
/// Removes every whitespace character from `tokens`, so that token text such
/// as `state . reload` compares equal to `state.reload`.
fn normalize_tokens(tokens: &str) -> String {
    tokens.chars().filter(|ch| !ch.is_whitespace()).collect()
}
483
484fn get_simple_pat_ident(pat: &syn::Pat) -> Option<String> {
485 if let syn::Pat::Ident(p) = pat {
486 Some(p.ident.to_string())
487 } else {
488 None
489 }
490}
491
/// Returns the identifier that directly follows the first `.accounts.` in `s`.
///
/// The identifier is the longest run of alphanumeric characters and
/// underscores after the marker. `None` is returned when the marker is absent
/// or is not followed by at least one such character.
fn extract_account_name_from_str(s: &str) -> Option<String> {
    const MARKER: &str = ".accounts.";

    let (_, tail) = s.split_once(MARKER)?;
    let name_len = tail
        .find(|ch: char| !(ch.is_alphanumeric() || ch == '_'))
        .unwrap_or(tail.len());
    let name = &tail[..name_len];

    (!name.is_empty()).then(|| name.to_string())
}
503
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as a Rust file; the fixtures below must be valid.
    fn parse_file(source: &str) -> syn::File {
        syn::parse_file(source).expect("source should parse")
    }

    /// Parses `source` and runs the collector over it.
    fn index_source(source: &str) -> InstructionIndex {
        collect_instruction_index(&parse_file(source))
    }

    #[test]
    fn collects_functions_from_modules_and_impls() {
        let index = index_source(
            r#"
            mod instructions {
                pub fn process() {}
            }

            impl Processor {
                pub fn handle() {}
            }
            "#,
        );

        let qualified: Vec<&str> = index
            .functions
            .iter()
            .map(|function| function.qualified_name.as_str())
            .collect();

        // Module functions are prefixed by the module, methods by the impl type.
        assert_eq!(qualified, ["instructions::process", "Processor::handle"]);
    }

    #[test]
    fn models_guards_calls_and_writes_in_order() {
        let index = index_source(
            r#"
            pub fn process(mut state: Account<'info, Vault>, authority: Signer<'info>) -> Result<()> {
                if state.owner != authority.key() {
                    return Err(ErrorCode::Unauthorized.into());
                }

                require!(authority.is_signer, ErrorCode::Unauthorized);

                let account = Vault::try_deserialize(&mut data)?;
                invoke_signed(&ix, &accounts, signer_seeds)?;
                state.reload()?;
                state.counter += account.amount;
                state.authority = authority.key();
                Ok(())
            }
            "#,
        );
        assert_eq!(index.functions.len(), 1);

        let function = &index.functions[0];
        assert_eq!(function.qualified_name, "process");

        // Guards: the `if` condition and the `require!` macro.
        let guards = &function.guards;
        assert_eq!(guards.len(), 2);
        assert!(guards.iter().any(|guard| guard.references_owner));
        assert!(guards.iter().any(|guard| guard.references_signer));
        assert!(guards.iter().any(|guard| guard.references_key));

        // Calls: one of each interesting kind.
        let has_call = |kind: CallKind, needle: &str| {
            function
                .calls
                .iter()
                .any(|call| call.kind == kind && call.callee.contains(needle))
        };
        assert!(has_call(CallKind::Deserialization, "Vault::try_deserialize"));
        assert!(has_call(CallKind::Cpi, "invoke_signed"));
        assert!(has_call(CallKind::Reload, "state.reload"));

        // Writes: the compound assignment and the plain assignment.
        let has_write = |target: &str| function.writes.iter().any(|write| write.target == target);
        assert!(has_write("state.counter"));
        assert!(has_write("state.authority"));

        // Ordering: CPI, then reload, then the counter write.
        let first_call_order = |kind: CallKind, missing: &str| {
            function
                .calls
                .iter()
                .find(|call| call.kind == kind)
                .expect(missing)
                .order
        };
        let cpi_order = first_call_order(CallKind::Cpi, "cpi call should be recorded");
        let reload_order = first_call_order(CallKind::Reload, "reload call should be recorded");
        let write_order = function
            .writes
            .iter()
            .find(|write| write.target == "state.counter")
            .expect("counter write should be recorded")
            .order;

        assert!(cpi_order < reload_order);
        assert!(reload_order < write_order);
    }

    #[test]
    fn models_eq_style_guard_macros_with_both_operands() {
        let index = index_source(
            r#"
            pub fn process(account: AccountInfo<'info>, authority: Signer<'info>) -> Result<()> {
                require_keys_eq!(account.owner, authority.key(), ErrorCode::Unauthorized);
                Ok(())
            }
            "#,
        );

        let function = &index.functions[0];
        let guard = function
            .guards
            .iter()
            .find(|guard| guard.kind == GuardKind::RequireMacro)
            .expect("require macro guard should be recorded");

        // Both operands contribute: `owner` from the left, `key` from the right.
        assert!(guard.expression.contains("=="));
        assert!(guard.references_owner);
        assert!(guard.references_key);
    }
}