1use heck::ToSnakeCase;
8use proc_macro::TokenStream;
9use proc_macro2::{Span, TokenStream as TokenStream2};
10use quote::{format_ident, quote};
11use syn::{
12 parse::{Parse, ParseStream},
13 parse_macro_input, Attribute, Data, DeriveInput, Error, Fields, Ident, LitStr, Token, Type,
14};
15
16#[proc_macro_derive(Svid, attributes(svid))]
21pub fn derive_svid(input: TokenStream) -> TokenStream {
22 let input = parse_macro_input!(input as DeriveInput);
23 expand_svid(input)
24 .unwrap_or_else(Error::into_compile_error)
25 .into()
26}
27
28fn expand_svid(input: DeriveInput) -> Result<TokenStream2, Error> {
29 let enum_name = &input.ident;
30 let data = match &input.data {
31 Data::Enum(d) => d,
32 _ => {
33 return Err(Error::new_spanned(
34 &input.ident,
35 "Svid can only be derived on enums",
36 ))
37 }
38 };
39
40 if !has_repr_u8(&input.attrs) {
41 return Err(Error::new_spanned(
42 &input.ident,
43 "Svid requires `#[repr(u8)]` on the enum so variant discriminants \
44 can be cast to `u8` for the SVID tag field",
45 ));
46 }
47
48 let registry_name = parse_registry_attr(&input.attrs)?;
49
50 let mut variant_idents = Vec::with_capacity(data.variants.len());
51 for v in &data.variants {
52 if !matches!(v.fields, Fields::Unit) {
53 return Err(Error::new_spanned(
54 v,
55 "Svid variants must be unit variants like `UserId = 1`",
56 ));
57 }
58 variant_idents.push(v.ident.clone());
59 }
60
61 let id_blocks: Vec<TokenStream2> = variant_idents
62 .iter()
63 .map(|v| {
64 let marker = format_ident!("{}Marker", v);
65 quote_id_block(enum_name, v, &marker)
66 })
67 .collect();
68
69 let reserved_guards: Vec<TokenStream2> = variant_idents
70 .iter()
71 .map(|v| {
72 let msg = format!(
73 "svid: variant `{}::{}` uses tag value {} which is reserved by svid::RANDOM_ID_TAG for SvidGenerator::generate_random()",
74 enum_name, v, 127
75 );
76 quote! {
77 const _: () = {
78 assert!(
79 (#enum_name::#v as u8) != ::svid::RANDOM_ID_TAG,
80 #msg
81 );
82 };
83 }
84 })
85 .collect();
86
87 let registry_block = registry_name
88 .map(|reg| quote_registry_block(®, &variant_idents))
89 .unwrap_or_else(TokenStream2::new);
90
91 Ok(quote! {
92 #(#reserved_guards)*
93 #(#id_blocks)*
94 #registry_block
95 })
96}
97
98fn has_repr_u8(attrs: &[Attribute]) -> bool {
99 for attr in attrs {
100 if !attr.path().is_ident("repr") {
101 continue;
102 }
103 let mut found = false;
104 let _ = attr.parse_nested_meta(|meta| {
105 if meta.path.is_ident("u8") {
106 found = true;
107 }
108 Ok(())
109 });
110 if found {
111 return true;
112 }
113 }
114 false
115}
116
117fn parse_registry_attr(attrs: &[Attribute]) -> Result<Option<Ident>, Error> {
118 let mut registry = None;
119 for attr in attrs {
120 if !attr.path().is_ident("svid") {
121 continue;
122 }
123 attr.parse_nested_meta(|meta| {
124 if meta.path.is_ident("registry") {
125 let value = meta.value()?;
126 let id: Ident = value.parse()?;
127 registry = Some(id);
128 Ok(())
129 } else {
130 Err(meta.error("unknown svid attribute; expected `registry = Ident`"))
131 }
132 })?;
133 }
134 Ok(registry)
135}
136
137fn quote_id_block(enum_name: &Ident, v: &Ident, marker: &Ident) -> TokenStream2 {
138 let qualified_label = format!("{}::{}", enum_name, v);
139 quote! {
140 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
141 #[cfg_attr(feature = "diesel", derive(::diesel::AsExpression, ::diesel::FromSqlRow))]
142 #[cfg_attr(feature = "diesel", diesel(sql_type = ::diesel::sql_types::BigInt))]
143 #[cfg_attr(feature = "ts", derive(::ts_rs::TS))]
144 #[cfg_attr(feature = "ts", ts(export))]
145 #[repr(transparent)]
146 pub struct #v(pub i64);
147
148 impl ::std::convert::From<i64> for #v {
149 fn from(id: i64) -> Self { Self(id) }
150 }
151
152 impl #v {
153 pub fn to_base58(&self) -> String {
154 ::svid::bs58::encode(self.0.to_be_bytes()).into_string()
155 }
156
157 pub fn from_base58(s: &str) -> ::std::result::Result<Self, String> {
158 use ::svid::SvidExt;
159 let id_val = ::svid::decode_i64_base58(s)?;
160 let expected = #enum_name::#v as u8;
161 let got = id_val.tag();
162 if got != expected {
163 return Err(format!(
164 "Invalid SVID tag: expected {} ({}), got {}",
165 expected, #qualified_label, got
166 ));
167 }
168 Ok(Self(id_val))
169 }
170
171 #[inline]
172 pub fn to_str(&self) -> String {
173 ::svid::id_to_human_readable(self.0)
174 }
175
176 #[inline]
177 pub fn from_str_id(s: &str) -> ::std::result::Result<Self, String> {
178 ::svid::human_readable_to_id_expecting(s, #enum_name::#v as u8).map(Self)
179 }
180
181 #[inline]
182 pub fn to_i64(&self) -> i64 { self.0 }
183 }
184
185 impl ::std::fmt::Display for #v {
186 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
187 write!(f, "{}", self.to_str())
188 }
189 }
190
191 impl ::std::str::FromStr for #v {
192 type Err = String;
193 fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
194 if s.len() == ::svid::HUMAN_READABLE_LEN {
195 Self::from_str_id(s)
196 } else {
197 Self::from_base58(s)
198 }
199 }
200 }
201
202 #[cfg(feature = "serde")]
203 impl ::serde::Serialize for #v {
204 fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
205 where S: ::serde::Serializer
206 {
207 serializer.serialize_str(&self.to_str())
208 }
209 }
210
211 #[cfg(feature = "serde")]
212 impl<'de> ::serde::Deserialize<'de> for #v {
213 fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
214 where D: ::serde::Deserializer<'de>
215 {
216 let s = <String as ::serde::Deserialize>::deserialize(deserializer)?;
217 if s.len() == ::svid::HUMAN_READABLE_LEN {
218 Self::from_str_id(&s).map_err(::serde::de::Error::custom)
219 } else {
220 Self::from_base58(&s).map_err(::serde::de::Error::custom)
221 }
222 }
223 }
224
225 #[cfg(feature = "diesel")]
226 impl ::diesel::serialize::ToSql<::diesel::sql_types::BigInt, ::diesel::pg::Pg> for #v {
227 fn to_sql<'b>(
228 &'b self,
229 out: &mut ::diesel::serialize::Output<'b, '_, ::diesel::pg::Pg>,
230 ) -> ::diesel::serialize::Result {
231 use ::std::io::Write;
232 out.write_all(&self.0.to_be_bytes())?;
233 Ok(::diesel::serialize::IsNull::No)
234 }
235 }
236
237 #[cfg(feature = "diesel")]
238 impl ::diesel::deserialize::FromSql<::diesel::sql_types::BigInt, ::diesel::pg::Pg> for #v {
239 fn from_sql(
240 bytes: <::diesel::pg::Pg as ::diesel::backend::Backend>::RawValue<'_>,
241 ) -> ::diesel::deserialize::Result<Self> {
242 let v = <i64 as ::diesel::deserialize::FromSql<
243 ::diesel::sql_types::BigInt,
244 ::diesel::pg::Pg,
245 >>::from_sql(bytes)?;
246 Ok(Self(v))
247 }
248 }
249
250 #[cfg(feature = "autosurgeon")]
251 impl ::autosurgeon::Reconcile for #v {
252 type Key<'a> = ::autosurgeon::reconcile::NoKey;
253 fn reconcile<R: ::autosurgeon::Reconciler>(
254 &self,
255 reconciler: R,
256 ) -> ::std::result::Result<(), R::Error> {
257 self.0.reconcile(reconciler)
258 }
259 }
260
261 #[cfg(feature = "autosurgeon")]
262 impl ::autosurgeon::Hydrate for #v {
263 fn hydrate_int(
264 i: i64,
265 ) -> ::std::result::Result<Self, ::autosurgeon::HydrateError> {
266 Ok(Self(i))
267 }
268 }
269
270 #[derive(Debug, Clone, Copy, Default)]
271 pub struct #marker;
272
273 impl ::svid::SvidKind for #marker {
274 type Id = #v;
275 const TAG: u8 = #enum_name::#v as u8;
276 }
277 }
278}
279
280fn quote_registry_block(registry: &Ident, variants: &[Ident]) -> TokenStream2 {
281 let fields: Vec<Ident> = variants
282 .iter()
283 .map(|v| Ident::new(&v.to_string().to_snake_case(), v.span()))
284 .collect();
285 let markers: Vec<Ident> = variants
286 .iter()
287 .map(|v| format_ident!("{}Marker", v))
288 .collect();
289
290 quote! {
291 #[cfg(not(target_arch = "wasm32"))]
292 pub struct #registry {
293 #( pub #fields: ::svid::IdGenerator<#markers>, )*
294 }
295
296 #[cfg(not(target_arch = "wasm32"))]
297 impl #registry {
298 pub fn new(is_client: bool) -> Self {
299 Self {
300 #( #fields: ::svid::IdGenerator::new(is_client), )*
301 }
302 }
303
304 #[inline]
305 pub fn generate_id<T>(&self) -> T
306 where
307 Self: ::svid::GenerateId<T>,
308 {
309 <Self as ::svid::GenerateId<T>>::generate(self)
310 }
311 }
312
313 #(
314 #[cfg(not(target_arch = "wasm32"))]
315 impl ::svid::GenerateId<#variants> for #registry {
316 #[inline]
317 fn generate(&self) -> #variants {
318 self.#fields.generate_id()
319 }
320 }
321 )*
322 }
323}
324
325#[proc_macro_derive(SvidDomain, attributes(svid))]
330pub fn derive_svid_domain(input: TokenStream) -> TokenStream {
331 let input = parse_macro_input!(input as DeriveInput);
332 expand_svid_domain(input)
333 .unwrap_or_else(Error::into_compile_error)
334 .into()
335}
336
337fn expand_svid_domain(input: DeriveInput) -> Result<TokenStream2, Error> {
338 let enum_name = &input.ident;
339 let data = match &input.data {
340 Data::Enum(d) => d,
341 _ => {
342 return Err(Error::new_spanned(
343 &input.ident,
344 "SvidDomain can only be derived on enums",
345 ))
346 }
347 };
348
349 let (error_label, tag_enum_override) = parse_svid_domain_attrs(&input.attrs)?;
350 let tag_enum = tag_enum_override.unwrap_or_else(|| Ident::new("SvidTag", Span::call_site()));
351
352 let mut variants_info: Vec<(Ident, Ident)> = Vec::with_capacity(data.variants.len());
353 let mut seen_inner: std::collections::HashMap<String, Ident> = std::collections::HashMap::new();
354 for v in &data.variants {
355 let inner = extract_single_ident_field(&v.fields)?;
356 if let Some(prev) = seen_inner.get(&inner.to_string()) {
357 return Err(Error::new_spanned(
358 &inner,
359 format!(
360 "duplicate inner type `{}` — SvidDomain emits `From<{}> for {}`, which would conflict with the impl for the earlier variant at `{}`",
361 inner, inner, enum_name, prev,
362 ),
363 ));
364 }
365 seen_inner.insert(inner.to_string(), inner.clone());
366 variants_info.push((v.ident.clone(), inner));
367 }
368
369 let variant_idents: Vec<&Ident> = variants_info.iter().map(|(vi, _)| vi).collect();
370 let inner_types: Vec<&Ident> = variants_info.iter().map(|(_, it)| it).collect();
371
372 let v1 = variant_idents.clone();
375 let v2 = variant_idents.clone();
376 let v3 = variant_idents.clone();
377 let v4 = variant_idents.clone();
378 let v5 = variant_idents.clone();
379 let t1 = inner_types.clone();
380 let t2 = inner_types.clone();
381 let t3 = inner_types.clone();
382 let t4 = inner_types.clone();
383 let t5 = inner_types.clone();
384 let t6 = inner_types.clone();
385 let t7 = inner_types.clone();
386
387 Ok(quote! {
388 impl #enum_name {
389 pub fn tag(&self) -> u8 {
390 match self {
391 #( #enum_name::#v1(_) => #tag_enum::#t1 as u8, )*
392 }
393 }
394
395 pub fn to_i64(&self) -> i64 {
396 match self {
397 #( #enum_name::#v2(id) => id.0, )*
398 }
399 }
400
401 pub fn to_base58(&self) -> String {
402 match self {
403 #( #enum_name::#v3(id) => id.to_base58(), )*
404 }
405 }
406
407 pub fn from_i64(id: i64) -> ::std::result::Result<Self, String> {
408 use ::svid::SvidExt;
409 let tag = id.tag();
410 #(
411 if tag == #tag_enum::#t2 as u8 {
412 return Ok(#enum_name::#v4(#t3(id)));
413 }
414 )*
415 Err(format!(concat!("Invalid ", #error_label, " tag: {}"), tag))
416 }
417
418 pub fn from_base58(s: &str) -> ::std::result::Result<Self, String> {
419 Self::from_i64(::svid::decode_i64_base58(s)?)
420 }
421
422 #[inline]
423 pub fn to_str(&self) -> String {
424 ::svid::id_to_human_readable(self.to_i64())
425 }
426
427 pub fn from_str_id(s: &str) -> ::std::result::Result<Self, String> {
428 let id_val = ::svid::human_readable_to_id(s)?;
429 Self::from_i64(id_val)
430 }
431 }
432
433 impl ::std::fmt::Display for #enum_name {
434 fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
435 write!(f, "{}", self.to_str())
436 }
437 }
438
439 impl ::std::str::FromStr for #enum_name {
440 type Err = String;
441 fn from_str(s: &str) -> ::std::result::Result<Self, Self::Err> {
442 if s.len() == ::svid::HUMAN_READABLE_LEN {
443 Self::from_str_id(s)
444 } else {
445 Self::from_base58(s)
446 }
447 }
448 }
449
450 #[cfg(feature = "serde")]
451 impl ::serde::Serialize for #enum_name {
452 fn serialize<S>(&self, serializer: S) -> ::std::result::Result<S::Ok, S::Error>
453 where S: ::serde::Serializer
454 {
455 serializer.serialize_str(&self.to_str())
456 }
457 }
458
459 #[cfg(feature = "serde")]
460 impl<'de> ::serde::Deserialize<'de> for #enum_name {
461 fn deserialize<D>(deserializer: D) -> ::std::result::Result<Self, D::Error>
462 where D: ::serde::Deserializer<'de>
463 {
464 use ::std::str::FromStr;
465 let s = <String as ::serde::Deserialize>::deserialize(deserializer)?;
466 Self::from_str(&s).map_err(::serde::de::Error::custom)
467 }
468 }
469
470 #(
471 impl ::std::convert::From<#t4> for #enum_name {
472 fn from(id: #t5) -> Self { #enum_name::#v5(id) }
473 }
474
475 impl ::std::convert::TryFrom<#enum_name> for #t6 {
476 type Error = String;
477 fn try_from(val: #enum_name) -> ::std::result::Result<Self, Self::Error> {
478 #[allow(unreachable_patterns)]
479 match val {
480 #enum_name::#variant_idents(id) => Ok(id),
481 _ => Err(format!(
482 "Expected tag for {} ({}), got tag {}",
483 stringify!(#t7),
484 #tag_enum::#inner_types as u8,
485 val.tag(),
486 )),
487 }
488 }
489 }
490 )*
491
492 impl ::std::convert::TryFrom<i64> for #enum_name {
493 type Error = String;
494 fn try_from(id: i64) -> ::std::result::Result<Self, Self::Error> {
495 Self::from_i64(id)
496 }
497 }
498
499 impl ::std::convert::From<#enum_name> for i64 {
500 fn from(val: #enum_name) -> Self { val.to_i64() }
501 }
502
503 #[cfg(feature = "diesel")]
504 impl ::diesel::serialize::ToSql<::diesel::sql_types::BigInt, ::diesel::pg::Pg> for #enum_name {
505 fn to_sql<'b>(
506 &'b self,
507 out: &mut ::diesel::serialize::Output<'b, '_, ::diesel::pg::Pg>,
508 ) -> ::diesel::serialize::Result {
509 use ::std::io::Write;
510 out.write_all(&self.to_i64().to_be_bytes())?;
511 Ok(::diesel::serialize::IsNull::No)
512 }
513 }
514
515 #[cfg(feature = "diesel")]
516 impl ::diesel::deserialize::FromSql<::diesel::sql_types::BigInt, ::diesel::pg::Pg> for #enum_name {
517 fn from_sql(
518 bytes: <::diesel::pg::Pg as ::diesel::backend::Backend>::RawValue<'_>,
519 ) -> ::diesel::deserialize::Result<Self> {
520 let v = <i64 as ::diesel::deserialize::FromSql<
521 ::diesel::sql_types::BigInt,
522 ::diesel::pg::Pg,
523 >>::from_sql(bytes)?;
524 <Self as ::std::convert::TryFrom<i64>>::try_from(v)
525 .map_err(|e: String| e.into())
526 }
527 }
528
529 #[cfg(feature = "autosurgeon")]
530 impl ::autosurgeon::Reconcile for #enum_name {
531 type Key<'a> = ::autosurgeon::reconcile::NoKey;
532 fn reconcile<R: ::autosurgeon::Reconciler>(
533 &self,
534 reconciler: R,
535 ) -> ::std::result::Result<(), R::Error> {
536 self.to_i64().reconcile(reconciler)
537 }
538 }
539
540 #[cfg(feature = "autosurgeon")]
541 impl ::autosurgeon::Hydrate for #enum_name {
542 fn hydrate_int(
543 i: i64,
544 ) -> ::std::result::Result<Self, ::autosurgeon::HydrateError> {
545 <Self as ::std::convert::TryFrom<i64>>::try_from(i)
546 .map_err(|e| ::autosurgeon::HydrateError::unexpected(
547 concat!("valid ", stringify!(#enum_name), " SVID tag"),
548 e,
549 ))
550 }
551 }
552 })
553}
554
555fn parse_svid_domain_attrs(attrs: &[Attribute]) -> Result<(LitStr, Option<Ident>), Error> {
556 let mut label = None;
557 let mut tag = None;
558 for attr in attrs {
559 if !attr.path().is_ident("svid") {
560 continue;
561 }
562 attr.parse_nested_meta(|meta| {
563 if meta.path.is_ident("error_label") {
564 let value = meta.value()?;
565 label = Some(value.parse::<LitStr>()?);
566 Ok(())
567 } else if meta.path.is_ident("tag") {
568 let value = meta.value()?;
569 tag = Some(value.parse::<Ident>()?);
570 Ok(())
571 } else {
572 Err(meta.error(
573 "unknown svid attribute; expected `error_label = \"...\"` or `tag = Ident`",
574 ))
575 }
576 })?;
577 }
578 let label = label.ok_or_else(|| {
579 Error::new(
580 Span::call_site(),
581 "SvidDomain requires `#[svid(error_label = \"...\")]`",
582 )
583 })?;
584 Ok((label, tag))
585}
586
587fn extract_single_ident_field(fields: &Fields) -> Result<Ident, Error> {
588 const MSG: &str = "SvidDomain variants must be single-field tuple variants whose inner type is a bare ident (e.g. `Folder(FolderId)`)";
589 let unnamed = match fields {
590 Fields::Unnamed(u) if u.unnamed.len() == 1 => u,
591 _ => return Err(Error::new_spanned(fields, MSG)),
592 };
593 let ty = &unnamed.unnamed[0].ty;
594 match ty {
595 Type::Path(tp)
596 if tp.qself.is_none()
597 && tp.path.segments.len() == 1
598 && tp.path.segments[0].arguments.is_empty() =>
599 {
600 Ok(tp.path.segments[0].ident.clone())
601 }
602 _ => Err(Error::new_spanned(ty, MSG)),
603 }
604}
605
606struct BridgeInput {
611 src: Ident,
612 dst: Ident,
613 arms: Vec<(Ident, Ident)>,
614}
615
616impl Parse for BridgeInput {
617 fn parse(input: ParseStream) -> syn::Result<Self> {
618 let src: Ident = input.parse()?;
619 let _: Token![->] = input.parse()?;
620 let dst: Ident = input.parse()?;
621 let content;
622 syn::braced!(content in input);
623 let mut arms = Vec::new();
624 while !content.is_empty() {
625 let variant: Ident = content.parse()?;
626 let inner_content;
627 syn::parenthesized!(inner_content in content);
628 let inner: Ident = inner_content.parse()?;
629 arms.push((variant, inner));
630 if !content.is_empty() {
631 let _: Token![,] = content.parse()?;
632 }
633 }
634 Ok(BridgeInput { src, dst, arms })
635 }
636}
637
638#[proc_macro]
639pub fn bridge(input: TokenStream) -> TokenStream {
640 let BridgeInput { src, dst, arms } = parse_macro_input!(input as BridgeInput);
641 let variant_idents: Vec<&Ident> = arms.iter().map(|(v, _)| v).collect();
642 let inner_types: Vec<&Ident> = arms.iter().map(|(_, t)| t).collect();
643
644 let expanded = quote! {
645 impl ::std::convert::From<#src> for #dst {
646 fn from(val: #src) -> Self {
647 match val {
648 #( #src::#variant_idents(id) => <#dst as ::std::convert::From<#inner_types>>::from(id), )*
649 }
650 }
651 }
652 };
653 expanded.into()
654}