static_key_internal/
lib.rs

1use std::cmp::Ordering;
2
3use proc_macro2::{Literal, Span};
4use syn::{Arm, LitStr, Path, Token, Type, parse::Parse, parse_macro_input};
5
6struct StaticMatch {
7    target_arch: LitStr,
8    crate_path: Path,
9    key: Path,
10    ty: Type,
11    semicolon: Token![;],
12    arms: Vec<Arm>,
13}
14
15impl Parse for StaticMatch {
16    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
17        let target_arch = input.parse()?;
18        let _: Token![,] = input.parse()?;
19        let crate_path = input.parse()?;
20        let _: Token![;] = input.parse()?;
21        let key = input.parse()?;
22        let _: Token![:] = input.parse()?;
23        let ty = input.parse()?;
24        let semicolon = input.parse()?;
25        let mut arms = Vec::new();
26        while !input.is_empty() {
27            arms.push(input.parse()?);
28        }
29        Ok(Self {
30            target_arch,
31            crate_path,
32            key,
33            ty,
34            semicolon,
35            arms,
36        })
37    }
38}
39
40#[proc_macro]
41pub fn parse_static_match(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
42    let StaticMatch {
43        target_arch,
44        crate_path,
45        key,
46        ty,
47        semicolon,
48        mut arms,
49    } = parse_macro_input!(tokens as StaticMatch);
50
51    if arms.is_empty() {
52        return syn::Error::new(semicolon.span, "static_match! cannot be used without arms")
53            .into_compile_error()
54            .into();
55    }
56
57    // Extract the most likely arm, which will be used as the fallthrough arm.
58    let likely_arm = arms
59        .iter_mut()
60        .position(|x| {
61            let mut found = false;
62            // Find the `#[likely]` attribute and remove it.
63            x.attrs.retain(|x| {
64                if x.path().is_ident("likely") {
65                    found = true;
66                    return false;
67                }
68                true
69            });
70            found
71        })
72        .unwrap_or(arms.len() - 1);
73
74    // Generate a matcher function that matches on a `&ty` and returns an index of the match arm.
75    // For the `likely` arm, it needs to be encoded as `usize::MAX` instead.
76    let matcher_arms: Vec<_> = arms
77        .iter()
78        .enumerate()
79        .map(|(idx, x)| {
80            let mapped_idx = match idx.cmp(&likely_arm) {
81                Ordering::Less => idx as isize,
82                Ordering::Equal => -1,
83                Ordering::Greater => idx as isize - 1,
84            };
85            Arm {
86                body: Box::new(
87                    syn::ExprLit {
88                        attrs: Vec::new(),
89                        lit: syn::Lit::new(Literal::isize_unsuffixed(mapped_idx)),
90                    }
91                    .into(),
92                ),
93                comma: Some(Token![,](Span::mixed_site())),
94                ..x.clone()
95            }
96        })
97        .collect();
98    let matcher_fn = quote::quote_spanned! { Span::mixed_site() =>
99        fn matcher(value: &#ty) -> isize {
100            match value {
101                #(#matcher_arms)*
102            }
103        }
104    };
105
106    let fallback_body = arms.remove(likely_arm).body;
107    let arms_len = arms.len();
108    let label_bodies: Vec<_> = arms.into_iter().map(|arm| arm.body).collect();
109
110    match target_arch.value().as_str() {
111        "x86_64" => {
112            let label_templates = (0..arms_len).map(|_| ".4byte {} - (2b + 5)");
113
114            quote::quote_spanned! { Span::mixed_site() => 'label:{unsafe{
115                #matcher_fn
116
117                // Check the type of key and expected type matches.
118                const _: *const #crate_path::StaticKey::<#ty> = &raw const #key;
119
120                ::core::arch::asm!(
121                    // Aligns the start to 8 byte boundary if doing so only require 1 bytes.
122                    // This means that the start address % 8 will be 0~6 (but never 7).
123                    // This ensures that we can at least atomically replace 2 bytes at a time.
124                    // We pad using 0x2E which is the CS prefix, so we still have 1 single instruction.
125                    ".p2align 3,0x2e,1",
126                    "2:",
127                    // 5 bytes
128                    "ud2; ud2; nop",
129                    r#".pushsection .data.static_match.jump_table"#,
130                    ".p2align 3",
131                    "3:",
132                    ".8byte {key}",
133                    ".8byte 2b",
134                    ".8byte {matcher}",
135                    ".8byte 0",
136                    #(#label_templates,)*
137                    ".popsection",
138                    r#".pushsection .text.startup.static_match.init"#,
139                    "4:",
140                    "lea rdi, [rip + 3b]",
141                    "jmp {register}",
142                    ".popsection",
143                    ".pushsection .init_array",
144                    ".8byte 4b",
145                    ".popsection",
146                    #(
147                        label { break 'label { match () { () => #label_bodies } }; },
148                    )*
149                    key = sym #key,
150                    matcher = sym matcher,
151                    register = sym #crate_path::CallSite::<#ty>::register,
152                    options(nomem, nostack)
153                );
154
155                match () { () => #fallback_body }
156            }}}
157        }
158        "riscv64" => {
159            let label_templates = (0..arms_len).map(|_| ".4byte {} - 2b");
160
161            quote::quote_spanned! { Span::mixed_site() => 'label:{unsafe{
162                #matcher_fn
163
164                // Check the type of key and expected type matches.
165                const _: *const #crate_path::StaticKey::<#ty> = &raw const #key;
166
167                ::core::arch::asm!(
168                    // When RISC-V QEMU runs userspace level emulation, a thread enters a infinite loop and
169                    // cannot leave it even if the instruction is modified and `sync_core()`.
170                    // Align the instruction so instruction can be atomically replaced without need to add
171                    // an infinite loop.
172                    ".p2align 2",
173                    "2:",
174                    ".option push",
175                    ".option norelax",
176                    ".option norvc",
177                    "unimp",
178                    ".option pop",
179                    r#".pushsection .data.static_match.jump_table"#,
180                    ".p2align 3",
181                    "3:",
182                    ".8byte {key}",
183                    ".8byte 2b",
184                    ".8byte {matcher}",
185                    ".8byte 0",
186                    #(#label_templates,)*
187                    ".popsection",
188                    r#".pushsection .text.startup.static_match.init"#,
189                    "4:",
190                    "la a0, 3b",
191                    "j {register}",
192                    ".popsection",
193                    ".pushsection .init_array",
194                    ".8byte 4b",
195                    ".popsection",
196                    #(
197                        label { break 'label { match () { () => #label_bodies } }; },
198                    )*
199                    key = sym #key,
200                    matcher = sym matcher,
201                    register = sym #crate_path::CallSite::<#ty>::register,
202                    options(nomem, nostack)
203                );
204
205                match () { () => #fallback_body }
206            }}}
207        }
208        _ => {
209            unimplemented!();
210        }
211    }
212    .into()
213}