Skip to main content

svid_macros/
lib.rs

//! Procedural macros for the `svid` crate: the `Svid` and `SvidDomain` derives
//! and the function-like `bridge!` macro.
//!
//! Re-exported from `svid` as `svid::Svid`, `svid::SvidDomain`, `svid::bridge!`.
//! Generated code references `::svid::*` paths, so depending on this crate
//! directly (without `svid`) will not work.
6
7use heck::ToSnakeCase;
8use proc_macro::TokenStream;
9use proc_macro2::{Span, TokenStream as TokenStream2};
10use quote::{format_ident, quote};
11use syn::{
12    parse::{Parse, ParseStream},
13    parse_macro_input, Attribute, Data, DeriveInput, Error, Fields, Ident, LitStr, Token, Type,
14};
15
// ============================================================================
// #[derive(Svid)]
// ============================================================================
19
20#[proc_macro_derive(Svid, attributes(svid))]
21pub fn derive_svid(input: TokenStream) -> TokenStream {
22    let input = parse_macro_input!(input as DeriveInput);
23    expand_svid(input)
24        .unwrap_or_else(Error::into_compile_error)
25        .into()
26}
27
28fn expand_svid(input: DeriveInput) -> Result<TokenStream2, Error> {
29    let enum_name = &input.ident;
30    let data = match &input.data {
31        Data::Enum(d) => d,
32        _ => {
33            return Err(Error::new_spanned(
34                &input.ident,
35                "Svid can only be derived on enums",
36            ))
37        }
38    };
39
40    if !has_repr_u8(&input.attrs) {
41        return Err(Error::new_spanned(
42            &input.ident,
43            "Svid requires `#[repr(u8)]` on the enum so variant discriminants \
44             can be cast to `u8` for the SVID tag field",
45        ));
46    }
47
48    let registry_name = parse_registry_attr(&input.attrs)?;
49
50    let mut variant_idents = Vec::with_capacity(data.variants.len());
51    for v in &data.variants {
52        if !matches!(v.fields, Fields::Unit) {
53            return Err(Error::new_spanned(
54                v,
55                "Svid variants must be unit variants like `UserId = 1`",
56            ));
57        }
58        variant_idents.push(v.ident.clone());
59    }
60
61    let id_blocks: Vec<TokenStream2> = variant_idents
62        .iter()
63        .map(|v| {
64            let marker = format_ident!("{}Marker", v);
65            quote_id_block(enum_name, v, &marker)
66        })
67        .collect();
68
69    let reserved_guards: Vec<TokenStream2> = variant_idents
70        .iter()
71        .map(|v| {
72            let msg = format!(
73                "svid: variant `{}::{}` uses tag value {} which is reserved by svid::RANDOM_ID_TAG for SvidGenerator::generate_random()",
74                enum_name, v, 127
75            );
76            quote! {
77                const _: () = {
78                    assert!(
79                        (#enum_name::#v as u8) != ::svid::RANDOM_ID_TAG,
80                        #msg
81                    );
82                };
83            }
84        })
85        .collect();
86
87    let registry_block = registry_name
88        .map(|reg| quote_registry_block(&reg, &variant_idents))
89        .unwrap_or_else(TokenStream2::new);
90
91    Ok(quote! {
92        #(#reserved_guards)*
93        #(#id_blocks)*
94        #registry_block
95    })
96}
97
98fn has_repr_u8(attrs: &[Attribute]) -> bool {
99    for attr in attrs {
100        if !attr.path().is_ident("repr") {
101            continue;
102        }
103        let mut found = false;
104        let _ = attr.parse_nested_meta(|meta| {
105            if meta.path.is_ident("u8") {
106                found = true;
107            }
108            Ok(())
109        });
110        if found {
111            return true;
112        }
113    }
114    false
115}
116
117fn parse_registry_attr(attrs: &[Attribute]) -> Result<Option<Ident>, Error> {
118    let mut registry = None;
119    for attr in attrs {
120        if !attr.path().is_ident("svid") {
121            continue;
122        }
123        attr.parse_nested_meta(|meta| {
124            if meta.path.is_ident("registry") {
125                let value = meta.value()?;
126                let id: Ident = value.parse()?;
127                registry = Some(id);
128                Ok(())
129            } else {
130                Err(meta.error("unknown svid attribute; expected `registry = Ident`"))
131            }
132        })?;
133    }
134    Ok(registry)
135}
136
137fn quote_id_block(enum_name: &Ident, v: &Ident, marker: &Ident) -> TokenStream2 {
138    let qualified_label = format!("{}::{}", enum_name, v);
139    quote! {
140        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
141        #[cfg_attr(feature = "diesel", derive(::diesel::AsExpression, ::diesel::FromSqlRow))]
142        #[cfg_attr(feature = "diesel", diesel(sql_type = ::diesel::sql_types::BigInt))]
143        #[cfg_attr(feature = "ts", derive(::ts_rs::TS))]
144        #[cfg_attr(feature = "ts", ts(export))]
145        #[repr(transparent)]
146        pub struct #v(pub i64);
147
148        impl ::std::convert::From<i64> for #v {
149            fn from(id: i64) -> Self { Self(id) }
150        }
151
152        impl #v {
153            pub fn to_base58(&self) -> String {
154                ::svid::bs58::encode(self.0.to_be_bytes()).into_string()
155            }
156
157            pub fn from_base58(s: &str) -> ::std::result::Result<Self, String> {
158                use ::svid::SvidExt;
159                let id_val = ::svid::decode_i64_base58(s)?;
160                let expected = #enum_name::#v as u8;
161                let got = id_val.tag();
162                if got != expected {
163                    return Err(format!(
164                        "Invalid SVID tag: expected {} ({}), got {}",
165                        expected, #qualified_label, got
166                    ));
167                }
168                Ok(Self(id_val))
169            }
170
171            #[inline]
172            pub fn to_str(&self) -> String {
173                ::svid::id_to_human_readable(self.0)
174            }
175
176            #[inline]
177            pub fn from_str_id(s: &str) -> ::std::result::Result<Self, String> {
178                ::svid::human_readable_to_id_expecting(s, #enum_name::#v as u8).map(Self)
179            }
180
181            #[inline]
182            pub fn to_i64(&self) -> i64 { self.0 }
183        }
184
185        impl ::std::fmt::Display for #v {
186            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
187                write!(f, "{}", self.to_str())
188            }
189        }
190
191        impl ::std::str::FromStr for #v {
192            type Err = String;
193            fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
194                if s.len() == ::svid::HUMAN_READABLE_LEN {
195                    Self::from_str_id(s)
196                } else {
197                    Self::from_base58(s)
198                }
199            }
200        }
201
202        #[cfg(feature = "serde")]
203        impl ::serde::Serialize for #v {
204            fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
205            where S: ::serde::Serializer
206            {
207                serializer.serialize_str(&self.to_str())
208            }
209        }
210
211        #[cfg(feature = "serde")]
212        impl<'de> ::serde::Deserialize<'de> for #v {
213            fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
214            where D: ::serde::Deserializer<'de>
215            {
216                let s = <String as ::serde::Deserialize>::deserialize(deserializer)?;
217                if s.len() == ::svid::HUMAN_READABLE_LEN {
218                    Self::from_str_id(&s).map_err(::serde::de::Error::custom)
219                } else {
220                    Self::from_base58(&s).map_err(::serde::de::Error::custom)
221                }
222            }
223        }
224
225        #[cfg(feature = "diesel")]
226        impl ::diesel::serialize::ToSql<::diesel::sql_types::BigInt, ::diesel::pg::Pg> for #v {
227            fn to_sql<'b>(
228                &'b self,
229                out: &mut ::diesel::serialize::Output<'b, '_, ::diesel::pg::Pg>,
230            ) -> ::diesel::serialize::Result {
231                use ::std::io::Write;
232                out.write_all(&self.0.to_be_bytes())?;
233                Ok(::diesel::serialize::IsNull::No)
234            }
235        }
236
237        #[cfg(feature = "diesel")]
238        impl ::diesel::deserialize::FromSql<::diesel::sql_types::BigInt, ::diesel::pg::Pg> for #v {
239            fn from_sql(
240                bytes: <::diesel::pg::Pg as ::diesel::backend::Backend>::RawValue<'_>,
241            ) -> ::diesel::deserialize::Result<Self> {
242                let v = <i64 as ::diesel::deserialize::FromSql<
243                    ::diesel::sql_types::BigInt,
244                    ::diesel::pg::Pg,
245                >>::from_sql(bytes)?;
246                Ok(Self(v))
247            }
248        }
249
250        #[cfg(feature = "autosurgeon")]
251        impl ::autosurgeon::Reconcile for #v {
252            type Key<'a> = ::autosurgeon::reconcile::NoKey;
253            fn reconcile<R: ::autosurgeon::Reconciler>(
254                &self,
255                reconciler: R,
256            ) -> ::std::result::Result<(), R::Error> {
257                self.0.reconcile(reconciler)
258            }
259        }
260
261        #[cfg(feature = "autosurgeon")]
262        impl ::autosurgeon::Hydrate for #v {
263            fn hydrate_int(
264                i: i64,
265            ) -> ::std::result::Result<Self, ::autosurgeon::HydrateError> {
266                Ok(Self(i))
267            }
268        }
269
270        #[derive(Debug, Clone, Copy, Default)]
271        pub struct #marker;
272
273        impl ::svid::SvidKind for #marker {
274            type Id = #v;
275            const TAG: u8 = #enum_name::#v as u8;
276        }
277    }
278}
279
280fn quote_registry_block(registry: &Ident, variants: &[Ident]) -> TokenStream2 {
281    let fields: Vec<Ident> = variants
282        .iter()
283        .map(|v| Ident::new(&v.to_string().to_snake_case(), v.span()))
284        .collect();
285    let markers: Vec<Ident> = variants
286        .iter()
287        .map(|v| format_ident!("{}Marker", v))
288        .collect();
289
290    quote! {
291        #[cfg(not(target_arch = "wasm32"))]
292        pub struct #registry {
293            #( pub #fields: ::svid::IdGenerator<#markers>, )*
294        }
295
296        #[cfg(not(target_arch = "wasm32"))]
297        impl #registry {
298            pub fn new(is_client: bool) -> Self {
299                Self {
300                    #( #fields: ::svid::IdGenerator::new(is_client), )*
301                }
302            }
303
304            #[inline]
305            pub fn generate_id<T>(&self) -> T
306            where
307                Self: ::svid::GenerateId<T>,
308            {
309                <Self as ::svid::GenerateId<T>>::generate(self)
310            }
311        }
312
313        #(
314            #[cfg(not(target_arch = "wasm32"))]
315            impl ::svid::GenerateId<#variants> for #registry {
316                #[inline]
317                fn generate(&self) -> #variants {
318                    self.#fields.generate_id()
319                }
320            }
321        )*
322    }
323}
324
// ============================================================================
// #[derive(SvidDomain)]
// ============================================================================
328
329#[proc_macro_derive(SvidDomain, attributes(svid))]
330pub fn derive_svid_domain(input: TokenStream) -> TokenStream {
331    let input = parse_macro_input!(input as DeriveInput);
332    expand_svid_domain(input)
333        .unwrap_or_else(Error::into_compile_error)
334        .into()
335}
336
337fn expand_svid_domain(input: DeriveInput) -> Result<TokenStream2, Error> {
338    let enum_name = &input.ident;
339    let data = match &input.data {
340        Data::Enum(d) => d,
341        _ => {
342            return Err(Error::new_spanned(
343                &input.ident,
344                "SvidDomain can only be derived on enums",
345            ))
346        }
347    };
348
349    let (error_label, tag_enum_override) = parse_svid_domain_attrs(&input.attrs)?;
350    let tag_enum = tag_enum_override.unwrap_or_else(|| Ident::new("SvidTag", Span::call_site()));
351
352    let mut variants_info: Vec<(Ident, Ident)> = Vec::with_capacity(data.variants.len());
353    let mut seen_inner: std::collections::HashMap<String, Ident> = std::collections::HashMap::new();
354    for v in &data.variants {
355        let inner = extract_single_ident_field(&v.fields)?;
356        if let Some(prev) = seen_inner.get(&inner.to_string()) {
357            return Err(Error::new_spanned(
358                &inner,
359                format!(
360                    "duplicate inner type `{}` — SvidDomain emits `From<{}> for {}`, which would conflict with the impl for the earlier variant at `{}`",
361                    inner, inner, enum_name, prev,
362                ),
363            ));
364        }
365        seen_inner.insert(inner.to_string(), inner.clone());
366        variants_info.push((v.ident.clone(), inner));
367    }
368
369    let variant_idents: Vec<&Ident> = variants_info.iter().map(|(vi, _)| vi).collect();
370    let inner_types: Vec<&Ident> = variants_info.iter().map(|(_, it)| it).collect();
371
372    // Repeated separately because quote! repetition requires each `#var` in a
373    // group to have the same number of iterations.
374    let v1 = variant_idents.clone();
375    let v2 = variant_idents.clone();
376    let v3 = variant_idents.clone();
377    let v4 = variant_idents.clone();
378    let v5 = variant_idents.clone();
379    let t1 = inner_types.clone();
380    let t2 = inner_types.clone();
381    let t3 = inner_types.clone();
382    let t4 = inner_types.clone();
383    let t5 = inner_types.clone();
384    let t6 = inner_types.clone();
385    let t7 = inner_types.clone();
386
387    Ok(quote! {
388        impl #enum_name {
389            pub fn tag(&self) -> u8 {
390                match self {
391                    #( #enum_name::#v1(_) => #tag_enum::#t1 as u8, )*
392                }
393            }
394
395            pub fn to_i64(&self) -> i64 {
396                match self {
397                    #( #enum_name::#v2(id) => id.0, )*
398                }
399            }
400
401            pub fn to_base58(&self) -> String {
402                match self {
403                    #( #enum_name::#v3(id) => id.to_base58(), )*
404                }
405            }
406
407            pub fn from_i64(id: i64) -> ::std::result::Result<Self, String> {
408                use ::svid::SvidExt;
409                let tag = id.tag();
410                #(
411                    if tag == #tag_enum::#t2 as u8 {
412                        return Ok(#enum_name::#v4(#t3(id)));
413                    }
414                )*
415                Err(format!(concat!("Invalid ", #error_label, " tag: {}"), tag))
416            }
417
418            pub fn from_base58(s: &str) -> ::std::result::Result<Self, String> {
419                Self::from_i64(::svid::decode_i64_base58(s)?)
420            }
421
422            #[inline]
423            pub fn to_str(&self) -> String {
424                ::svid::id_to_human_readable(self.to_i64())
425            }
426
427            pub fn from_str_id(s: &str) -> ::std::result::Result<Self, String> {
428                let id_val = ::svid::human_readable_to_id(s)?;
429                Self::from_i64(id_val)
430            }
431        }
432
433        impl ::std::fmt::Display for #enum_name {
434            fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
435                write!(f, "{}", self.to_str())
436            }
437        }
438
439        impl ::std::str::FromStr for #enum_name {
440            type Err = String;
441            fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
442                if s.len() == ::svid::HUMAN_READABLE_LEN {
443                    Self::from_str_id(s)
444                } else {
445                    Self::from_base58(s)
446                }
447            }
448        }
449
450        #[cfg(feature = "serde")]
451        impl ::serde::Serialize for #enum_name {
452            fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
453            where S: ::serde::Serializer
454            {
455                serializer.serialize_str(&self.to_str())
456            }
457        }
458
459        #[cfg(feature = "serde")]
460        impl<'de> ::serde::Deserialize<'de> for #enum_name {
461            fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
462            where D: ::serde::Deserializer<'de>
463            {
464                use ::std::str::FromStr;
465                let s = <String as ::serde::Deserialize>::deserialize(deserializer)?;
466                Self::from_str(&s).map_err(::serde::de::Error::custom)
467            }
468        }
469
470        #(
471            impl ::std::convert::From<#t4> for #enum_name {
472                fn from(id: #t5) -> Self { #enum_name::#v5(id) }
473            }
474
475            impl ::std::convert::TryFrom<#enum_name> for #t6 {
476                type Error = String;
477                fn try_from(val: #enum_name) -> ::std::result::Result<Self, Self::Error> {
478                    #[allow(unreachable_patterns)]
479                    match val {
480                        #enum_name::#variant_idents(id) => Ok(id),
481                        _ => Err(format!(
482                            "Expected tag for {} ({}), got tag {}",
483                            stringify!(#t7),
484                            #tag_enum::#inner_types as u8,
485                            val.tag(),
486                        )),
487                    }
488                }
489            }
490        )*
491
492        impl ::std::convert::TryFrom<i64> for #enum_name {
493            type Error = String;
494            fn try_from(id: i64) -> ::std::result::Result<Self, Self::Error> {
495                Self::from_i64(id)
496            }
497        }
498
499        impl ::std::convert::From<#enum_name> for i64 {
500            fn from(val: #enum_name) -> Self { val.to_i64() }
501        }
502
503        #[cfg(feature = "diesel")]
504        impl ::diesel::serialize::ToSql<::diesel::sql_types::BigInt, ::diesel::pg::Pg> for #enum_name {
505            fn to_sql<'b>(
506                &'b self,
507                out: &mut ::diesel::serialize::Output<'b, '_, ::diesel::pg::Pg>,
508            ) -> ::diesel::serialize::Result {
509                use ::std::io::Write;
510                out.write_all(&self.to_i64().to_be_bytes())?;
511                Ok(::diesel::serialize::IsNull::No)
512            }
513        }
514
515        #[cfg(feature = "diesel")]
516        impl ::diesel::deserialize::FromSql<::diesel::sql_types::BigInt, ::diesel::pg::Pg> for #enum_name {
517            fn from_sql(
518                bytes: <::diesel::pg::Pg as ::diesel::backend::Backend>::RawValue<'_>,
519            ) -> ::diesel::deserialize::Result<Self> {
520                let v = <i64 as ::diesel::deserialize::FromSql<
521                    ::diesel::sql_types::BigInt,
522                    ::diesel::pg::Pg,
523                >>::from_sql(bytes)?;
524                <Self as ::std::convert::TryFrom<i64>>::try_from(v)
525                    .map_err(|e: String| e.into())
526            }
527        }
528
529        #[cfg(feature = "autosurgeon")]
530        impl ::autosurgeon::Reconcile for #enum_name {
531            type Key<'a> = ::autosurgeon::reconcile::NoKey;
532            fn reconcile<R: ::autosurgeon::Reconciler>(
533                &self,
534                reconciler: R,
535            ) -> ::std::result::Result<(), R::Error> {
536                self.to_i64().reconcile(reconciler)
537            }
538        }
539
540        #[cfg(feature = "autosurgeon")]
541        impl ::autosurgeon::Hydrate for #enum_name {
542            fn hydrate_int(
543                i: i64,
544            ) -> ::std::result::Result<Self, ::autosurgeon::HydrateError> {
545                <Self as ::std::convert::TryFrom<i64>>::try_from(i)
546                    .map_err(|e| ::autosurgeon::HydrateError::unexpected(
547                        concat!("valid ", stringify!(#enum_name), " SVID tag"),
548                        e,
549                    ))
550            }
551        }
552    })
553}
554
555fn parse_svid_domain_attrs(attrs: &[Attribute]) -> Result<(LitStr, Option<Ident>), Error> {
556    let mut label = None;
557    let mut tag = None;
558    for attr in attrs {
559        if !attr.path().is_ident("svid") {
560            continue;
561        }
562        attr.parse_nested_meta(|meta| {
563            if meta.path.is_ident("error_label") {
564                let value = meta.value()?;
565                label = Some(value.parse::<LitStr>()?);
566                Ok(())
567            } else if meta.path.is_ident("tag") {
568                let value = meta.value()?;
569                tag = Some(value.parse::<Ident>()?);
570                Ok(())
571            } else {
572                Err(meta.error(
573                    "unknown svid attribute; expected `error_label = \"...\"` or `tag = Ident`",
574                ))
575            }
576        })?;
577    }
578    let label = label.ok_or_else(|| {
579        Error::new(
580            Span::call_site(),
581            "SvidDomain requires `#[svid(error_label = \"...\")]`",
582        )
583    })?;
584    Ok((label, tag))
585}
586
587fn extract_single_ident_field(fields: &Fields) -> Result<Ident, Error> {
588    const MSG: &str = "SvidDomain variants must be single-field tuple variants whose inner type is a bare ident (e.g. `Folder(FolderId)`)";
589    let unnamed = match fields {
590        Fields::Unnamed(u) if u.unnamed.len() == 1 => u,
591        _ => return Err(Error::new_spanned(fields, MSG)),
592    };
593    let ty = &unnamed.unnamed[0].ty;
594    match ty {
595        Type::Path(tp)
596            if tp.qself.is_none()
597                && tp.path.segments.len() == 1
598                && tp.path.segments[0].arguments.is_empty() =>
599        {
600            Ok(tp.path.segments[0].ident.clone())
601        }
602        _ => Err(Error::new_spanned(ty, MSG)),
603    }
604}
605
// ============================================================================
// svid::bridge!(Src -> Dst { Variant(IdType), ... })
// ============================================================================
609
610struct BridgeInput {
611    src: Ident,
612    dst: Ident,
613    arms: Vec<(Ident, Ident)>,
614}
615
616impl Parse for BridgeInput {
617    fn parse(input: ParseStream) -> syn::Result<Self> {
618        let src: Ident = input.parse()?;
619        let _: Token![->] = input.parse()?;
620        let dst: Ident = input.parse()?;
621        let content;
622        syn::braced!(content in input);
623        let mut arms = Vec::new();
624        while !content.is_empty() {
625            let variant: Ident = content.parse()?;
626            let inner_content;
627            syn::parenthesized!(inner_content in content);
628            let inner: Ident = inner_content.parse()?;
629            arms.push((variant, inner));
630            if !content.is_empty() {
631                let _: Token![,] = content.parse()?;
632            }
633        }
634        Ok(BridgeInput { src, dst, arms })
635    }
636}
637
638#[proc_macro]
639pub fn bridge(input: TokenStream) -> TokenStream {
640    let BridgeInput { src, dst, arms } = parse_macro_input!(input as BridgeInput);
641    let variant_idents: Vec<&Ident> = arms.iter().map(|(v, _)| v).collect();
642    let inner_types: Vec<&Ident> = arms.iter().map(|(_, t)| t).collect();
643
644    let expanded = quote! {
645        impl ::std::convert::From<#src> for #dst {
646            fn from(val: #src) -> Self {
647                match val {
648                    #( #src::#variant_idents(id) => <#dst as ::std::convert::From<#inner_types>>::from(id), )*
649                }
650            }
651        }
652    };
653    expanded.into()
654}