// xtee_utee_macros/lib.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright (C) 2026 KylinSoft Co., Ltd. <https://www.kylinos.cn/>

4extern crate proc_macro;
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::{format_ident, quote};
9use syn::{FnArg, Item, ItemFn, ItemMod, parse_macro_input, spanned::Spanned};
10
11#[proc_macro_attribute]
12pub fn ta_create(_args: TokenStream, input: TokenStream) -> TokenStream {
13    input
14}
15
16#[proc_macro_attribute]
17pub fn ta_open_session(_args: TokenStream, input: TokenStream) -> TokenStream {
18    input
19}
20
21#[proc_macro_attribute]
22pub fn ta_close_session(_args: TokenStream, input: TokenStream) -> TokenStream {
23    input
24}
25
26#[proc_macro_attribute]
27pub fn ta_destroy(_args: TokenStream, input: TokenStream) -> TokenStream {
28    input
29}
30
31#[proc_macro_attribute]
32pub fn ta_invoke_command(_args: TokenStream, input: TokenStream) -> TokenStream {
33    input
34}
35
36#[proc_macro_attribute]
37pub fn ta_acl_check(_args: TokenStream, input: TokenStream) -> TokenStream {
38    input
39}
40
41/// 聚合 TA 入口函数并生成 `TrustedApplication`(socket 模型)。用法:`xtee_ta! { #[ta_create] fn create() ... ... }`,无需额外 `mod`。
42#[proc_macro]
43pub fn xtee_ta(input: TokenStream) -> TokenStream {
44    let file = parse_macro_input!(input as syn::File);
45    match expand_xtee_ta_items(file.items) {
46        Ok(ts) => ts.into(),
47        Err(e) => e.to_compile_error().into(),
48    }
49}
50
51/// 可选:仍支持 `#[xtee_ta] mod foo { ... }`(旧写法)。
52#[proc_macro_attribute]
53pub fn xtee_ta_module(_args: TokenStream, input: TokenStream) -> TokenStream {
54    let mut item_mod = parse_macro_input!(input as ItemMod);
55    let ta_struct_ident = format_ident!("Ta");
56
57    let Some((_, items)) = &mut item_mod.content else {
58        return syn::Error::new(
59            item_mod.span(),
60            "#[xtee_ta_module] only supports inline modules",
61        )
62        .to_compile_error()
63        .into();
64    };
65
66    match expand_xtee_ta_items_mut(items, ta_struct_ident) {
67        Ok(()) => quote!(#item_mod).into(),
68        Err(e) => e.to_compile_error().into(),
69    }
70}
71
72fn expand_xtee_ta_items(items: Vec<Item>) -> Result<TokenStream2, syn::Error> {
73    let ta_struct_ident = format_ident!("Ta");
74    let mut items = items;
75    expand_xtee_ta_items_mut(&mut items, ta_struct_ident)?;
76    Ok(quote! { #(#items)* })
77}
78
79fn expand_xtee_ta_items_mut(
80    items: &mut Vec<Item>,
81    ta_struct_ident: syn::Ident,
82) -> Result<(), syn::Error> {
83    let mut create_fn: Option<ItemFn> = None;
84    let mut open_fn: Option<ItemFn> = None;
85    let mut close_fn: Option<ItemFn> = None;
86    let mut destroy_fn: Option<ItemFn> = None;
87    let mut invoke_fn: Option<ItemFn> = None;
88    let mut acl_fn: Option<ItemFn> = None;
89
90    for item in items.iter_mut() {
91        let Item::Fn(func) = item else {
92            return Err(syn::Error::new(
93                item.span(),
94                "xtee_ta! only supports `fn` items (put `use` at crate root)",
95            ));
96        };
97        let marker = extract_marker_and_strip(func);
98        let Some(marker) = marker else {
99            return Err(syn::Error::new(
100                func.span(),
101                "xtee_ta!: each `fn` must carry one of #[ta_create], #[ta_open_session], #[ta_close_session], #[ta_destroy], #[ta_invoke_command], #[ta_acl_check]",
102            ));
103        };
104        match marker.as_str() {
105            "ta_create" => {
106                if create_fn.replace(func.clone()).is_some() {
107                    return Err(syn::Error::new(
108                        func.span(),
109                        "duplicate #[ta_create] function",
110                    ));
111                }
112            }
113            "ta_open_session" => {
114                if open_fn.replace(func.clone()).is_some() {
115                    return Err(syn::Error::new(
116                        func.span(),
117                        "duplicate #[ta_open_session] function",
118                    ));
119                }
120            }
121            "ta_close_session" => {
122                if close_fn.replace(func.clone()).is_some() {
123                    return Err(syn::Error::new(
124                        func.span(),
125                        "duplicate #[ta_close_session] function",
126                    ));
127                }
128            }
129            "ta_destroy" => {
130                if destroy_fn.replace(func.clone()).is_some() {
131                    return Err(syn::Error::new(
132                        func.span(),
133                        "duplicate #[ta_destroy] function",
134                    ));
135                }
136            }
137            "ta_invoke_command" => {
138                if invoke_fn.replace(func.clone()).is_some() {
139                    return Err(syn::Error::new(
140                        func.span(),
141                        "duplicate #[ta_invoke_command] function",
142                    ));
143                }
144            }
145            "ta_acl_check" => {
146                if acl_fn.replace(func.clone()).is_some() {
147                    return Err(syn::Error::new(
148                        func.span(),
149                        "duplicate #[ta_acl_check] function",
150                    ));
151                }
152            }
153            _ => {
154                return Err(syn::Error::new(
155                    func.span(),
156                    "unknown #[ta_*] attribute for xtee_ta!",
157                ));
158            }
159        }
160    }
161
162    let Some(create_fn) = create_fn else {
163        return Err(syn::Error::new(
164            proc_macro2::Span::call_site(),
165            "missing #[ta_create] function",
166        ));
167    };
168    let Some(open_fn) = open_fn else {
169        return Err(syn::Error::new(
170            proc_macro2::Span::call_site(),
171            "missing #[ta_open_session] function",
172        ));
173    };
174    let Some(close_fn) = close_fn else {
175        return Err(syn::Error::new(
176            proc_macro2::Span::call_site(),
177            "missing #[ta_close_session] function",
178        ));
179    };
180    let Some(destroy_fn) = destroy_fn else {
181        return Err(syn::Error::new(
182            proc_macro2::Span::call_site(),
183            "missing #[ta_destroy] function",
184        ));
185    };
186    let Some(invoke_fn) = invoke_fn else {
187        return Err(syn::Error::new(
188            proc_macro2::Span::call_site(),
189            "missing #[ta_invoke_command] function",
190        ));
191    };
192
193    let create_ident = create_fn.sig.ident.clone();
194    let destroy_ident = destroy_fn.sig.ident.clone();
195
196    let (session_ctx_ty, open_call, close_call, invoke_call) =
197        build_context_and_calls(&open_fn, &close_fn, &invoke_fn)?;
198
199    let acl_check_impl = if let Some(acl_fn) = &acl_fn {
200        if acl_fn.sig.inputs.len() != 1 {
201            return Err(syn::Error::new(
202                acl_fn.sig.span(),
203                "#[ta_acl_check] expects fn(ca_auth_info: Option<&CaAuthInfo>)",
204            ));
205        }
206        let acl_ident = &acl_fn.sig.ident;
207        quote! {
208            fn acl_check(
209                &self,
210                ca_auth_info: Option<&teec_protocol::CaAuthInfo>,
211            ) -> xtee_utee::error::Result<()> {
212                __XteeIntoTaResult::into_ta_result(#acl_ident(ca_auth_info))
213            }
214        }
215    } else {
216        quote! {}
217    };
218
219    let impl_block = quote! {
220        use teec_protocol::Parameters;
221
222        pub struct #ta_struct_ident;
223
224        trait __XteeIntoTaResult {
225            fn into_ta_result(self) -> xtee_utee::error::Result<()>;
226        }
227
228        impl __XteeIntoTaResult for () {
229            fn into_ta_result(self) -> xtee_utee::error::Result<()> {
230                Ok(())
231            }
232        }
233
234        impl __XteeIntoTaResult for xtee_utee::error::Result<()> {
235            fn into_ta_result(self) -> xtee_utee::error::Result<()> {
236                self
237            }
238        }
239
240        impl xtee_utee::ta_manager::TrustedApplication for #ta_struct_ident {
241            type SessionContext = #session_ctx_ty;
242
243            fn create(&self) -> xtee_utee::error::Result<()> {
244                __XteeIntoTaResult::into_ta_result(#create_ident())
245            }
246
247            #acl_check_impl
248
249            fn open_session(
250                &self,
251                params: &mut Parameters,
252            ) -> xtee_utee::error::Result<Self::SessionContext> {
253                #open_call
254            }
255
256            fn close_session(
257                &self,
258                ctx: &mut Self::SessionContext,
259            ) -> xtee_utee::error::Result<()> {
260                __XteeIntoTaResult::into_ta_result(#close_call)
261            }
262
263            fn destroy(&self) -> xtee_utee::error::Result<()> {
264                __XteeIntoTaResult::into_ta_result(#destroy_ident())
265            }
266
267            fn invoke_command(
268                &self,
269                cmd_id: u32,
270                params: &mut Parameters,
271                ctx: &mut Self::SessionContext,
272            ) -> xtee_utee::error::Result<()> {
273                __XteeIntoTaResult::into_ta_result(#invoke_call)
274            }
275        }
276    };
277
278    let parsed: syn::File = syn::parse2(impl_block).map_err(|e| {
279        syn::Error::new(
280            proc_macro2::Span::call_site(),
281            format!("xtee_ta: failed to parse generated items: {e}"),
282        )
283    })?;
284    for item in parsed.items {
285        items.push(item);
286    }
287
288    Ok(())
289}
290
291fn extract_marker_and_strip(func: &mut ItemFn) -> Option<String> {
292    let mut marker: Option<String> = None;
293    func.attrs.retain(|attr| {
294        let Some(last) = attr.path().segments.last() else {
295            return true;
296        };
297        let name = last.ident.to_string();
298        let is_marker = matches!(
299            name.as_str(),
300            "ta_create"
301                | "ta_open_session"
302                | "ta_close_session"
303                | "ta_destroy"
304                | "ta_invoke_command"
305                | "ta_acl_check"
306        );
307        if is_marker {
308            marker = Some(name);
309            false
310        } else {
311            true
312        }
313    });
314    marker
315}
316
317fn build_context_and_calls(
318    open_fn: &ItemFn,
319    close_fn: &ItemFn,
320    invoke_fn: &ItemFn,
321) -> Result<(TokenStream2, TokenStream2, TokenStream2, TokenStream2), syn::Error> {
322    let open_arg_count = open_fn.sig.inputs.len();
323    let close_arg_count = close_fn.sig.inputs.len();
324    let invoke_arg_count = invoke_fn.sig.inputs.len();
325
326    if !(open_arg_count == 1 || open_arg_count == 2) {
327        return Err(syn::Error::new(
328            open_fn.sig.span(),
329            "#[ta_open_session] expects fn(&mut Parameters) or fn(&mut Parameters, &mut T)",
330        ));
331    }
332
333    if !(close_arg_count == 0 || close_arg_count == 1) {
334        return Err(syn::Error::new(
335            close_fn.sig.span(),
336            "#[ta_close_session] expects fn() or fn(&mut T)",
337        ));
338    }
339
340    if !(invoke_arg_count == 2 || invoke_arg_count == 3) {
341        return Err(syn::Error::new(
342            invoke_fn.sig.span(),
343            "#[ta_invoke_command] expects fn(cmd_id, &mut Parameters) or fn(&mut T, cmd_id, &mut Parameters)",
344        ));
345    }
346
347    let open_ident = &open_fn.sig.ident;
348    let close_ident = &close_fn.sig.ident;
349    let invoke_ident = &invoke_fn.sig.ident;
350
351    if open_arg_count == 1 {
352        if close_arg_count != 0 || invoke_arg_count != 2 {
353            return Err(syn::Error::new(
354                open_fn.sig.span(),
355                "no-session-context mode requires close_session() and invoke_command(cmd_id, params)",
356            ));
357        }
358        let session_ctx_ty = quote! { () };
359        let open_call = quote! {
360            __XteeIntoTaResult::into_ta_result(#open_ident(params))?;
361            Ok(())
362        };
363        let close_call = quote! { #close_ident() };
364        let invoke_call = quote! { #invoke_ident(cmd_id, params) };
365        return Ok((session_ctx_ty, open_call, close_call, invoke_call));
366    }
367
368    let ctx_ty = extract_mut_ref_type(
369        open_fn
370            .sig
371            .inputs
372            .iter()
373            .nth(1)
374            .expect("checked arg count above"),
375    )?;
376    if close_arg_count != 1 || invoke_arg_count != 3 {
377        return Err(syn::Error::new(
378            open_fn.sig.span(),
379            "session-context mode requires close_session(&mut T) and invoke_command(&mut T, cmd_id, params)",
380        ));
381    }
382    let close_ctx_ty = extract_mut_ref_type(
383        close_fn
384            .sig
385            .inputs
386            .iter()
387            .next()
388            .expect("checked arg count above"),
389    )?;
390    let invoke_ctx_ty = extract_mut_ref_type(
391        invoke_fn
392            .sig
393            .inputs
394            .iter()
395            .next()
396            .expect("checked arg count above"),
397    )?;
398    if quote!(#ctx_ty).to_string() != quote!(#close_ctx_ty).to_string()
399        || quote!(#ctx_ty).to_string() != quote!(#invoke_ctx_ty).to_string()
400    {
401        return Err(syn::Error::new(
402            open_fn.sig.span(),
403            "session context type T must be consistent across ta_open_session/ta_close_session/ta_invoke_command",
404        ));
405    }
406
407    let session_ctx_ty = quote! { #ctx_ty };
408    let open_call = quote! {
409        let mut ctx: #ctx_ty = Default::default();
410        __XteeIntoTaResult::into_ta_result(#open_ident(params, &mut ctx))?;
411        Ok(ctx)
412    };
413    let close_call = quote! { #close_ident(ctx) };
414    let invoke_call = quote! { #invoke_ident(ctx, cmd_id, params) };
415    Ok((session_ctx_ty, open_call, close_call, invoke_call))
416}
417
418fn extract_mut_ref_type(arg: &FnArg) -> Result<&syn::Type, syn::Error> {
419    if let FnArg::Typed(pat_ty) = arg {
420        if let syn::Type::Reference(type_ref) = pat_ty.ty.as_ref() {
421            if type_ref.mutability.is_some() {
422                return Ok(type_ref.elem.as_ref());
423            }
424        }
425    }
426    Err(syn::Error::new(arg.span(), "argument must be &mut T"))
427}