1mod build;
7
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::parse::Parser;
11use syn::punctuated::Punctuated;
12use syn::{parse_macro_input, Attribute, DeriveInput, Meta, Token};
13
14fn has_flag(attrs: &[Attribute], name: &str) -> bool {
15 for attr in attrs {
16 if !attr.path().is_ident("beam") {
17 continue;
18 }
19 if let Meta::List(list) = &attr.meta {
20 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
22 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
23 for meta in metas {
24 if let Meta::Path(path) = meta {
25 if path.is_ident(name) {
26 return true;
27 }
28 }
29 }
30 }
31 }
32 }
33 false
34}
35
36fn get_version_value(attrs: &[Attribute]) -> Option<syn::Ident> {
37 for attr in attrs {
38 if !attr.path().is_ident("beam") {
39 continue;
40 }
41 if let Meta::List(list) = &attr.meta {
42 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
43 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
44 for meta in metas {
45 if let Meta::NameValue(nv) = meta {
46 if nv.path.is_ident("min_version") {
47 if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. }) = &nv.value {
48 return Some(syn::Ident::new(&lit_str.value(), lit_str.span()));
49 }
50 }
51 }
52 }
53 }
54 }
55 }
56 None
57}
58
59fn get_profile_value(attrs: &[Attribute]) -> Option<u8> {
60 for attr in attrs {
61 if !attr.path().is_ident("beam") {
62 continue;
63 }
64 if let Meta::List(list) = &attr.meta {
65 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
66 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
67 for meta in metas {
68 if let Meta::NameValue(nv) = meta {
69 if nv.path.is_ident("profile") {
70 if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Int(lit_int), .. }) = &nv.value {
71 if let Ok(profile) = lit_int.base10_parse::<u8>() {
72 return Some(profile);
73 }
74 }
75 }
76 }
77 }
78 }
79 }
80 }
81 None
82}
83
84fn get_profile_type(attrs: &[Attribute]) -> Option<syn::Type> {
85 for attr in attrs {
86 if !attr.path().is_ident("beam") {
87 continue;
88 }
89 if let Meta::List(list) = &attr.meta {
90 let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
91 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
92 for meta in metas {
93 if let Meta::List(profile_list) = meta {
94 if profile_list.path.is_ident("profile") {
95 if let Ok(ty) = syn::parse2::<syn::Type>(profile_list.tokens.clone()) {
97 return Some(ty);
98 }
99 }
100 }
101 }
102 }
103 }
104 }
105 None
106}
107
108fn has_attr(attrs: &[Attribute], name: &str) -> bool {
109 attrs.iter().any(|attr| attr.path().is_ident(name))
110}
111
112fn get_error_message(attrs: &[Attribute]) -> Option<String> {
113 for attr in attrs {
114 if attr.path().is_ident("error") {
115 if let Meta::List(list) = &attr.meta {
116 if let Ok(lit_str) = syn::parse2::<syn::LitStr>(list.tokens.clone()) {
117 return Some(lit_str.value());
118 }
119 }
120 }
121 }
122 None
123}
124
125#[proc_macro_derive(Beamable, attributes(beam))]
130pub fn derive_beamable(input: TokenStream) -> TokenStream {
131 let input = parse_macro_input!(input as DeriveInput);
132 let name = &input.ident;
133
134 let confidential = has_flag(&input.attrs, "confidential");
135 let nonrep = has_flag(&input.attrs, "nonrepudiable");
136 let compressed = has_flag(&input.attrs, "compressed");
137 let prioritized = has_flag(&input.attrs, "prioritized");
138 let message_integrity = has_flag(&input.attrs, "message_integrity");
139 let frame_integrity = has_flag(&input.attrs, "frame_integrity");
140 let min_version = get_version_value(&input.attrs);
141 let profile_value = get_profile_value(&input.attrs);
142 let profile_type = get_profile_type(&input.attrs);
143
144 if profile_value.is_some() && profile_type.is_some() {
146 return syn::Error::new_spanned(
147 &input,
148 "Cannot specify both numeric profile (= N) and type-based profile (Type) simultaneously",
149 )
150 .to_compile_error()
151 .into();
152 }
153
154 let (profile_confidential, profile_nonrep, profile_min_version) = match profile_value {
156 Some(1) => (true, true, Some(syn::Ident::new("V1", name.span()))), Some(2) => (true, true, Some(syn::Ident::new("V1", name.span()))), Some(p) if p > 2 => (false, false, None),
159 _ => (false, false, None),
160 };
161
162 let final_confidential = profile_confidential || confidential;
164 let final_nonrep = profile_nonrep || nonrep;
165 let final_min_version = profile_min_version.or(min_version);
166 let final_message_integrity = message_integrity;
167 let final_frame_integrity = frame_integrity;
168
169 let mut feature_checks = Vec::new();
170
171 if final_confidential && !cfg!(feature = "aead") {
172 feature_checks.push(quote! {
173 compile_error!(concat!(
174 "Message type `", stringify!(#name), "` is marked as confidential ",
175 "but the `aead` feature is not enabled. ",
176 "Enable the feature in Cargo.toml: features = [\"aead\"]"
177 ));
178 });
179 }
180
181 if final_nonrep && !cfg!(feature = "signature") {
182 feature_checks.push(quote! {
183 compile_error!(concat!(
184 "Message type `", stringify!(#name), "` is marked as non-repudiable ",
185 "but the `signature` feature is not enabled. ",
186 "Enable the feature in Cargo.toml: features = [\"signature\"]"
187 ));
188 });
189 }
190
191 if compressed && !cfg!(feature = "compress") {
192 feature_checks.push(quote! {
193 compile_error!(concat!(
194 "Message type `", stringify!(#name), "` is marked as compressed ",
195 "but the `compress` feature is not enabled. ",
196 "Enable the feature in Cargo.toml: features = [\"compress\"]"
197 ));
198 });
199 }
200
201 if (final_message_integrity || final_frame_integrity) && !cfg!(feature = "digest") {
202 feature_checks.push(quote! {
203 compile_error!(concat!(
204 "Message type `", stringify!(#name), "` is marked as requiring message integrity ",
205 "but the `digest` feature is not enabled. ",
206 "Enable the feature in Cargo.toml: features = [\"digest\"]"
207 ));
208 });
209 }
210
211 let min_version_value = if let Some(version) = final_min_version {
212 quote! { ::tightbeam::Version::#version }
213 } else {
214 quote! { ::tightbeam::Version::V0 }
215 };
216
217 let _has_profile = profile_type.is_some();
218 let profile_type_impl = if let Some(profile_ty) = &profile_type {
219 quote! {
220 const HAS_PROFILE: bool = true;
221 type Profile = #profile_ty;
222 }
223 } else {
224 quote! {
226 const HAS_PROFILE: bool = false;
227 type Profile = ::tightbeam::crypto::profiles::TightbeamProfile;
228 }
229 };
230
231 let oid_validation_helpers = if let Some(profile_ty) = &profile_type {
236 quote! {
239 #[cfg(feature = "digest")]
240 impl ::tightbeam::builder::private::SealedDigestOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::DigestOid> for #name
241 where
242 #name: ::tightbeam::Message,
243 {}
244
245 #[cfg(feature = "digest")]
246 impl ::tightbeam::builder::CheckDigestOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::DigestOid> for #name
247 where
248 #name: ::tightbeam::Message,
249 {
250 const RESULT: () = ();
251 }
252
253 #[cfg(feature = "aead")]
254 impl ::tightbeam::builder::private::SealedAeadOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::AeadOid> for #name
255 where
256 #name: ::tightbeam::Message,
257 {}
258
259 #[cfg(feature = "aead")]
260 impl ::tightbeam::builder::CheckAeadOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::AeadOid> for #name
261 where
262 #name: ::tightbeam::Message,
263 {
264 const RESULT: () = ();
265 }
266
267 #[cfg(feature = "signature")]
268 impl ::tightbeam::builder::private::SealedSignatureOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::SignatureAlg> for #name
269 where
270 #name: ::tightbeam::Message,
271 {}
272
273 #[cfg(feature = "signature")]
274 impl ::tightbeam::builder::CheckSignatureOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::SignatureAlg> for #name
275 where
276 #name: ::tightbeam::Message,
277 {
278 const RESULT: () = ();
279 }
280 }
281 } else {
282 quote! {
285 #[cfg(feature = "digest")]
286 impl<D: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::private::SealedDigestOid<D> for #name
287 where
288 #name: ::tightbeam::Message,
289 {}
290
291 #[cfg(feature = "digest")]
292 impl<D: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::CheckDigestOid<D> for #name
293 where
294 #name: ::tightbeam::Message,
295 {
296 const RESULT: () = ();
297 }
298
299 #[cfg(feature = "aead")]
300 impl<C: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::private::SealedAeadOid<C> for #name
301 where
302 #name: ::tightbeam::Message,
303 {}
304
305 #[cfg(feature = "aead")]
306 impl<C: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::CheckAeadOid<C> for #name
307 where
308 #name: ::tightbeam::Message,
309 {
310 const RESULT: () = ();
311 }
312
313 #[cfg(feature = "signature")]
314 impl<S: ::tightbeam::crypto::sign::SignatureAlgorithmIdentifier> ::tightbeam::builder::private::SealedSignatureOid<S> for #name
315 where
316 #name: ::tightbeam::Message,
317 {}
318
319 #[cfg(feature = "signature")]
320 impl<S: ::tightbeam::crypto::sign::SignatureAlgorithmIdentifier> ::tightbeam::builder::CheckSignatureOid<S> for #name
321 where
322 #name: ::tightbeam::Message,
323 {
324 const RESULT: () = ();
325 }
326 }
327 };
328
329 let expanded = quote! {
330 const _: () = {
331 #(#feature_checks)*
332 };
333
334 impl ::tightbeam::Message for #name {
335 const MUST_BE_CONFIDENTIAL: bool = #final_confidential;
336 const MUST_BE_NON_REPUDIABLE: bool = #final_nonrep;
337 const MUST_HAVE_MESSAGE_INTEGRITY: bool = #final_message_integrity;
338 const MUST_HAVE_FRAME_INTEGRITY: bool = #final_frame_integrity;
339 const MUST_BE_COMPRESSED: bool = #compressed;
340 const MUST_BE_PRIORITIZED: bool = #prioritized;
341 const MIN_VERSION: ::tightbeam::Version = #min_version_value;
342 #profile_type_impl
343 }
344
345 #oid_validation_helpers
346 };
347
348 TokenStream::from(expanded)
349}
350
351#[proc_macro_derive(Flaggable)]
356pub fn derive_flaggable(input: TokenStream) -> TokenStream {
357 let input = parse_macro_input!(input as DeriveInput);
358 let name = &input.ident;
359 let name_str = name.to_string();
360
361 let expanded = quote! {
362 impl From<#name> for u8 {
363 fn from(val: #name) -> u8 {
364 val as u8
365 }
366 }
367
368 impl PartialEq<u8> for #name {
369 fn eq(&self, other: &u8) -> bool {
370 (*self as u8) == *other
371 }
372 }
373
374 impl #name {
375 pub const TYPE_NAME: &'static str = #name_str;
376 }
377 };
378
379 TokenStream::from(expanded)
380}
381
382#[proc_macro_derive(Errorizable, attributes(error, from))]
393pub fn derive_errorizable(input: TokenStream) -> TokenStream {
394 let input = parse_macro_input!(input as DeriveInput);
395 let name = &input.ident;
396
397 let data_enum = match &input.data {
398 syn::Data::Enum(data) => data,
399 _ => {
400 return syn::Error::new_spanned(&input, "Errorizable can only be derived for enums")
401 .to_compile_error()
402 .into();
403 }
404 };
405
406 let mut display_arms = Vec::new();
407 let mut from_impls = Vec::new();
408
409 for variant in &data_enum.variants {
410 let variant_name = &variant.ident;
411
412 let error_msg = get_error_message(&variant.attrs);
414 let has_from = has_attr(&variant.attrs, "from");
415
416 match &variant.fields {
418 syn::Fields::Unnamed(fields) => {
419 let field_count = fields.unnamed.len();
420 let field_bindings: Vec<_> = (0..field_count)
421 .map(|i| syn::Ident::new(&format!("f{i}"), variant_name.span()))
422 .collect();
423
424 if let Some(msg) = error_msg {
425 if msg.contains("{expected") || msg.contains("{received") {
427 display_arms.push(quote! {
429 #name::#variant_name(ref f0) => {
430 write!(f, #msg, expected = f0.expected, received = f0.received)
431 }
432 });
433 } else {
434 display_arms.push(quote! {
435 #name::#variant_name(#(ref #field_bindings),*) => {
436 write!(f, #msg, #(#field_bindings),*)
437 }
438 });
439 }
440 } else {
441 display_arms.push(quote! {
442 #name::#variant_name(#(ref #field_bindings),*) => {
443 write!(f, "{}", stringify!(#variant_name))
444 }
445 });
446 }
447
448 if has_from && field_count == 1 {
450 let field_type = &fields.unnamed.first().unwrap().ty;
451 from_impls.push(quote! {
452 impl From<#field_type> for #name {
453 fn from(err: #field_type) -> Self {
454 #name::#variant_name(err)
455 }
456 }
457
458 impl From<#name> for #field_type {
459 fn from(err: #name) -> Self {
460 match err {
461 #name::#variant_name(inner) => inner,
462 _ => panic!("Cannot convert {} to {}", stringify!(#name), stringify!(#field_type)),
463 }
464 }
465 }
466 });
467 }
468 }
469 syn::Fields::Named(fields) => {
470 let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
471
472 if let Some(msg) = error_msg {
473 display_arms.push(quote! {
474 #name::#variant_name { #(ref #field_names),* } => {
475 write!(f, #msg, #(#field_names = #field_names),*)
476 }
477 });
478 } else {
479 display_arms.push(quote! {
480 #name::#variant_name { .. } => {
481 write!(f, "{}", stringify!(#variant_name))
482 }
483 });
484 }
485 }
486 syn::Fields::Unit => {
487 if let Some(msg) = error_msg {
488 display_arms.push(quote! {
489 #name::#variant_name => write!(f, #msg)
490 });
491 } else {
492 display_arms.push(quote! {
493 #name::#variant_name => write!(f, "{}", stringify!(#variant_name))
494 });
495 }
496 }
497 }
498 }
499
500 let expanded = quote! {
501 impl core::fmt::Display for #name {
502 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
503 match self {
504 #(#display_arms,)*
505 }
506 }
507 }
508
509 impl core::error::Error for #name {}
510
511 #(#from_impls)*
512 };
513
514 TokenStream::from(expanded)
515}
516
517#[proc_macro]
519pub fn generate_builders(_input: TokenStream) -> TokenStream {
520 let macros: Vec<_> = build::BUILDER_CONFIGS.iter().map(build::generate_builder_macro).collect();
521
522 let output = quote! {
523 #(#macros)*
524 };
525
526 output.into()
527}