// u_num_it/lib.rs

1extern crate proc_macro;
2
3use std::collections::HashMap;
4
5use proc_macro2::{Span, TokenStream, TokenTree};
6
7use quote::{quote, ToTokens};
8use syn::{
9    parse::Parse, parse_macro_input, spanned::Spanned, Expr, ExprArray, ExprMatch, Ident, Pat,
10    PatRange, RangeLimits, Token,
11};
12
/// Kind of arm written in the user's `match`; used as the key under which arm bodies are stored.
///
/// The `Display` form is the `typenum::consts` name prefix (`N`, `P`, `U`, `False`); the `None`
/// (wildcard) and `Literal` variants render as an empty string.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
enum UType {
    /// Negative values (`N1`, `N2`, ...).
    N,
    /// Positive values as signed typenum integers (`P1`, `P2`, ...).
    P,
    /// Positive values as unsigned typenum integers (`U1`, `U2`, ...).
    U,
    /// Zero, expanded as `typenum::consts::False`.
    False,
    /// The wildcard `_` arm.
    None,
    /// An explicit integer literal arm such as `3` or `-1`.
    Literal(isize),
}

impl std::fmt::Display for UType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let prefix = match self {
            Self::N => "N",
            Self::P => "P",
            Self::U => "U",
            Self::False => "False",
            // Neither kind contributes a prefix of its own.
            Self::None | Self::Literal(_) => "",
        };
        f.write_str(prefix)
    }
}
35
36struct UNumIt {
37    range: Vec<isize>,
38    arms: HashMap<UType, Box<Expr>>,
39    expr: Box<Expr>,
40}
41
42fn range_boundary(val: &Option<Box<Expr>>) -> syn::Result<Option<isize>> {
43    if let Some(val) = val.clone() {
44        let string = val.to_token_stream().to_string().replace(' ', "");
45        let value = string
46            .parse::<isize>()
47            .map_err(|e| syn::Error::new(val.span(), format!("{e}: `{string}`").as_str()))?;
48
49        Ok(Some(value))
50    } else {
51        Ok(None)
52    }
53}
54
55impl Parse for UNumIt {
56    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
57        // Try to parse as array first, then fallback to range
58        let range: Vec<isize> = if input.peek(syn::token::Bracket) {
59            // Parse array syntax: [1, 2, 8, 22]
60            let array: ExprArray = input.parse()?;
61            let mut vals = array
62                .elems
63                .iter()
64                .map(|expr| {
65                    let raw = expr.to_token_stream().to_string();
66                    let norm = raw.replace([' ', '_'], "");
67                    norm.parse::<isize>().map_err(|e| {
68                        syn::Error::new(
69                            expr.span(),
70                            format!("invalid number in array: {e}: `{raw}` (normalized `{norm}`)"),
71                        )
72                    })
73                })
74                .collect::<syn::Result<Vec<isize>>>()?;
75            vals.sort();
76            vals.dedup();
77            vals
78        } else {
79            // Parse range syntax: 1..10 or 1..=10
80            let range: PatRange = input.parse()?;
81            let start = range_boundary(&range.start)?.unwrap_or(0);
82            let end = range_boundary(&range.end)?.unwrap_or(isize::MAX);
83            match &range.limits {
84                RangeLimits::HalfOpen(_) => (start..end).collect(),
85                RangeLimits::Closed(_) => (start..=end).collect(),
86            }
87        };
88
89        input.parse::<Token![,]>()?;
90        let matcher: ExprMatch = input.parse()?;
91
92        let mut arms = HashMap::new();
93
94        for arm in matcher.arms.iter() {
95            let u_type = match &arm.pat {
96                Pat::Ident(t) => match t.ident.to_token_stream().to_string().as_str() {
97                    "N" => UType::N,
98                    "P" => UType::P,
99                    "U" => UType::U,
100                    "False" => UType::False,
101                    _ => {
102                        return Err(syn::Error::new(
103                            t.span(),
104                            "expected idents N | P | U | False | _",
105                        ))
106                    }
107                },
108                Pat::Lit(lit_expr) => {
109                    // Parse literal numbers in match arms (normalize spaces & underscores; base-10 only)
110                    let raw = lit_expr.to_token_stream().to_string();
111                    let norm = raw.replace([' ', '_'], "");
112                    if norm.starts_with("0x") || norm.starts_with("0b") || norm.starts_with("0o") {
113                        return Err(syn::Error::new(
114                            lit_expr.span(),
115                            format!("unsupported non-decimal literal `{raw}`"),
116                        ));
117                    }
118                    let value = norm.parse::<isize>().map_err(|e| {
119                        syn::Error::new(
120                            lit_expr.span(),
121                            format!("invalid literal: {e}: `{raw}` (normalized `{norm}`)"),
122                        )
123                    })?;
124                    UType::Literal(value)
125                }
126                Pat::Wild(_) => UType::None,
127                _ => return Err(syn::Error::new(arm.pat.span(), "expected ident")),
128            };
129            let arm_expr = arm.body.clone();
130            if arms.insert(u_type, arm_expr.clone()).is_some() {
131                return Err(syn::Error::new(arm_expr.span(), "duplicate type"));
132            }
133        }
134
135        if arms.get(&UType::P).and(arms.get(&UType::U)).is_some() {
136            return Err(syn::Error::new(
137                matcher.span(),
138                "ambiguous type, don't use P and U in the same macro call",
139            ));
140        }
141
142        // Check for conflict between literal 0 and False (they represent the same value in typenum)
143        if arms
144            .get(&UType::Literal(0))
145            .and(arms.get(&UType::False))
146            .is_some()
147        {
148            return Err(syn::Error::new(
149                matcher.span(),
150                "ambiguous type, don't use literal 0 and False in the same macro call (they represent the same value)",
151            ));
152        }
153
154        let expr = matcher.expr;
155
156        Ok(UNumIt { range, arms, expr })
157    }
158}
159
160fn make_match_arm(i: &isize, body: &Expr, u_type: UType) -> TokenStream {
161    let match_expr = quote!(#i);
162
163    // Determine the typenum type for all cases
164    let i_str = if *i != 0 {
165        i.abs().to_string()
166    } else {
167        Default::default()
168    };
169
170    // Determine the type variant based on UType
171    let u_type_for_typenum = match u_type {
172        UType::Literal(0) => UType::False,
173        UType::Literal(val) if val < 0 => UType::N,
174        UType::Literal(val) if val > 0 => UType::P,
175        _ => u_type,
176    };
177
178    let typenum_type = TokenTree::Ident(Ident::new(
179        format!("{}{}", u_type_for_typenum, i_str).as_str(),
180        Span::mixed_site(),
181    ));
182    let type_variant = quote!(typenum::consts::#typenum_type);
183
184    // All match arms get NumType and use body as-is (no pattern replacement)
185    let body_tokens = body.to_token_stream();
186
187    quote! {
188        #match_expr => {
189            type NumType = #type_variant;
190            #body_tokens
191        },
192    }
193}
194
195/// matches `typenum::consts` in a given range or array
196///
197/// use with an open or closed range, or an array of arbitrary numbers
198///
199/// use `P` | `N` | `U` | `False` | `_` or literals `1` | `-1` as match arms
200///
201/// a `NumType` type alias is available in each match arm,
202/// resolving to the specific typenum type for that value.
203/// Use `NumType` to reference the resolved type in the match arm body.
204///
205/// ## Example (range)
206///
207/// ```
208/// let x = 3;
209///
210/// u_num_it::u_num_it!(1..10, match x {
211///     U => {
212///         // NumType is typenum::consts::U3 when x=3
213///         let val = NumType::new();
214///         println!("{:?}", val);
215///         // UInt { msb: UInt { msb: UTerm, lsb: B1 }, lsb: B1 }
216///
217///         use typenum::ToInt;
218///         let num: usize = NumType::to_int();
219///         assert_eq!(num, 3);
220///     }
221/// })
222/// ```
223///
224/// ## Example (array)
225///
226/// ```
227/// let x = 8;
228///
229/// u_num_it::u_num_it!([1, 2, 8, 22], match x {
230///     P => {
231///         // NumType is typenum::consts::P8 when x=8
232///         use typenum::ToInt;
233///         let num: i32 = NumType::to_int();
234///         assert_eq!(num, 8);
235///     }
236/// })
237/// ```
238///
239/// ## Example (negative literal)
240/// ```
241/// let result = u_num_it::u_num_it!(-5..=5, match -3 {
242///     -3 => {
243///         use typenum::ToInt;
244///         let n: i32 = NumType::to_int();
245///         assert_eq!(n, -3);
246///         "ok"
247///     },
248///     N => "neg",
249///     _ => "other"
250/// });
251/// assert_eq!(result, "ok");
252/// ```
253#[proc_macro]
254pub fn u_num_it(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
255    let UNumIt { range, arms, expr } = parse_macro_input!(tokens as UNumIt);
256
257    let pos_u = arms.contains_key(&UType::U);
258
259    let expanded_arms = range.iter().filter_map(|i| {
260        // First check if there's a specific literal match for this number
261        if let Some(body) = arms.get(&UType::Literal(*i)) {
262            return Some(make_match_arm(i, body, UType::Literal(*i)));
263        }
264
265        // Otherwise, use the general type patterns
266        match i {
267            0 => arms
268                .get(&UType::False)
269                .map(|body| make_match_arm(i, body, UType::False)),
270            i if *i < 0 => arms
271                .get(&UType::N)
272                .map(|body| make_match_arm(i, body, UType::N)),
273            i if *i > 0 => {
274                if pos_u {
275                    arms.get(&UType::U)
276                        .map(|body| make_match_arm(i, body, UType::U))
277                } else {
278                    arms.get(&UType::P)
279                        .map(|body| make_match_arm(i, body, UType::P))
280                }
281            }
282            _ => unreachable!(),
283        }
284    });
285
286    let fallback = arms
287        .get(&UType::None)
288        .map(|body| {
289            quote! {
290                _ => {
291                    #body
292                },
293            }
294        })
295        .unwrap_or_else(|| {
296            let first = range.first().unwrap_or(&0);
297            let last = range.last().unwrap_or(&0);
298            quote! {
299                i => unreachable!("{i} not in range {}..={}", #first, #last),
300            }
301        });
302
303    let expanded = quote! {
304        match #expr {
305            #(#expanded_arms)*
306            #fallback
307        }
308    };
309
310    proc_macro::TokenStream::from(expanded)
311}