Skip to main content

sentio_core/
ast_index.rs

1use quote::ToTokens;
2use serde::Serialize;
3use syn::spanned::Spanned;
4use syn::{Attribute, Fields};
5
6#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
7pub struct AstIndex {
8    pub structs: Vec<AstStruct>,
9}
10
11#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
12pub struct AstStruct {
13    pub name: String,
14    pub attrs: Vec<AstAttr>,
15    pub fields: Vec<AstField>,
16    pub span: AstSpan,
17}
18
19#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
20pub struct AstField {
21    pub name: Option<String>,
22    pub ty: String,
23    pub attrs: Vec<AstAttr>,
24    pub span: AstSpan,
25}
26
27#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
28pub struct AstAttr {
29    pub path: String,
30    pub tokens: Option<String>,
31    pub span: AstSpan,
32}
33
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
35pub struct AstSpan {
36    pub start_line: usize,
37    pub start_column: usize,
38    pub end_line: usize,
39    pub end_column: usize,
40}
41
42pub fn collect_ast_index(file: &syn::File) -> AstIndex {
43    let mut structs = Vec::new();
44    collect_from_items(&file.items, &mut structs);
45    AstIndex { structs }
46}
47
48fn collect_from_items(items: &[syn::Item], structs: &mut Vec<AstStruct>) {
49    for item in items {
50        match item {
51            syn::Item::Struct(item) => structs.push(ast_struct_from_syn(item)),
52            syn::Item::Mod(module) => {
53                if let Some((_, nested_items)) = &module.content {
54                    collect_from_items(nested_items, structs);
55                }
56            }
57            _ => {}
58        }
59    }
60}
61
62pub(crate) fn ast_struct_from_syn(item: &syn::ItemStruct) -> AstStruct {
63    let fields = match &item.fields {
64        Fields::Named(named) => named.named.iter().map(ast_field_from_syn).collect(),
65        Fields::Unnamed(unnamed) => unnamed.unnamed.iter().map(ast_field_from_syn).collect(),
66        Fields::Unit => Vec::new(),
67    };
68
69    AstStruct {
70        name: item.ident.to_string(),
71        attrs: ast_attrs_from_syn(&item.attrs),
72        fields,
73        span: span_of(item.span()),
74    }
75}
76
77pub(crate) fn ast_field_from_syn(field: &syn::Field) -> AstField {
78    AstField {
79        name: field.ident.as_ref().map(|ident| ident.to_string()),
80        ty: field.ty.to_token_stream().to_string(),
81        attrs: ast_attrs_from_syn(&field.attrs),
82        span: span_of(field.span()),
83    }
84}
85
86pub(crate) fn ast_attrs_from_syn(attrs: &[Attribute]) -> Vec<AstAttr> {
87    attrs
88        .iter()
89        .map(|attr| AstAttr {
90            path: attr.path().to_token_stream().to_string(),
91            tokens: attr_tokens(attr),
92            span: span_of(attr.span()),
93        })
94        .collect()
95}
96
97fn attr_tokens(attr: &Attribute) -> Option<String> {
98    match &attr.meta {
99        syn::Meta::Path(_) => None,
100        syn::Meta::List(list) => Some(list.tokens.to_string()),
101        syn::Meta::NameValue(name_value) => Some(name_value.value.to_token_stream().to_string()),
102    }
103}
104
105pub(crate) fn span_of(span: proc_macro2::Span) -> AstSpan {
106    let start = span.start();
107    let end = span.end();
108
109    AstSpan {
110        start_line: start.line,
111        start_column: start.column,
112        end_line: end.line,
113        end_column: end.column,
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as a complete Rust file; every fixture here is valid,
    /// so a parse failure is a test bug.
    fn parse_file(source: &str) -> syn::File {
        syn::parse_file(source).expect("source should parse")
    }

    #[test]
    fn collects_struct_fields_attrs_and_spans() {
        // The raw string starts with a newline, so `#[derive(...)]` is line 2.
        let file = parse_file(
            r#"
            #[derive(Accounts)]
            pub struct Example<'info> {
                #[account(mut, signer)]
                pub authority: Signer<'info>,
                pub count: u64,
            }
            "#,
        );

        let index = collect_ast_index(&file);
        assert_eq!(index.structs.len(), 1);

        let item = &index.structs[0];
        assert_eq!(item.name, "Example");
        assert_eq!(item.attrs.len(), 1);
        assert_eq!(item.attrs[0].path, "derive");
        assert_eq!(item.attrs[0].tokens.as_deref(), Some("Accounts"));
        assert_eq!(item.fields.len(), 2);
        assert_eq!(item.fields[0].name.as_deref(), Some("authority"));
        assert_eq!(item.fields[0].attrs.len(), 1);
        assert_eq!(item.fields[0].attrs[0].path, "account");
        assert_eq!(item.fields[1].name.as_deref(), Some("count"));
        assert_eq!(item.fields[1].ty, "u64");
        assert!(item.fields[1].attrs.is_empty());

        // Spans (1-based lines): the struct span runs from its outer attribute
        // through the closing brace; a field span starts at its own attribute.
        assert_eq!(item.attrs[0].span.start_line, 2);
        assert_eq!(item.attrs[0].span.end_line, 2);
        assert_eq!(item.span.start_line, 2);
        assert_eq!(item.span.end_line, 7);
        assert_eq!(item.span.start_column, item.attrs[0].span.start_column);
        assert_eq!(item.fields[0].attrs[0].span.start_line, 4);
        assert_eq!(item.fields[0].span.start_line, 4);
        assert_eq!(item.fields[0].span.end_line, 5);
        assert_eq!(item.fields[1].span.start_line, 6);
        assert_eq!(item.fields[1].span.end_line, 6);
    }

    #[test]
    fn recurses_into_inline_modules_and_handles_every_field_kind() {
        let file = parse_file(
            r#"
            struct Unit;
            mod external;
            mod outer {
                #[repr(transparent)]
                struct Tuple(#[doc = "inner"] u8);
                mod inner {
                    #[non_exhaustive]
                    struct Deep { value: u8 }
                }
            }
            "#,
        );

        let index = collect_ast_index(&file);
        let names: Vec<&str> = index.structs.iter().map(|s| s.name.as_str()).collect();
        // Depth-first, source order; the out-of-line `mod external;` contributes nothing.
        assert_eq!(names, ["Unit", "Tuple", "Deep"]);

        let unit = &index.structs[0];
        assert!(unit.attrs.is_empty());
        assert!(unit.fields.is_empty());

        // List-form and name-value attributes, plus a positional field.
        let tuple = &index.structs[1];
        assert_eq!(tuple.attrs[0].path, "repr");
        assert_eq!(tuple.attrs[0].tokens.as_deref(), Some("transparent"));
        assert_eq!(tuple.fields.len(), 1);
        assert_eq!(tuple.fields[0].name, None);
        assert_eq!(tuple.fields[0].ty, "u8");
        assert_eq!(tuple.fields[0].attrs[0].path, "doc");
        assert_eq!(tuple.fields[0].attrs[0].tokens.as_deref(), Some("\"inner\""));

        // Path-only attribute carries no tokens.
        let deep = &index.structs[2];
        assert_eq!(deep.attrs[0].path, "non_exhaustive");
        assert_eq!(deep.attrs[0].tokens, None);
        assert_eq!(deep.fields[0].name.as_deref(), Some("value"));
    }
}