Skip to main content

tightbeam_derive/
lib.rs

//! Derive macros for TightBeam message types.
//!
//! This crate provides the `#[derive(Beamable)]` macro, which implements the
//! `Message` trait for structs, along with the `Flaggable` and `Errorizable`
//! derives and the `generate_builders!` function-like macro.
5
6mod build;
7
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::parse::Parser;
11use syn::punctuated::Punctuated;
12use syn::{parse_macro_input, Attribute, DeriveInput, Meta, Token};
13
14fn has_flag(attrs: &[Attribute], name: &str) -> bool {
15	for attr in attrs {
16		if !attr.path().is_ident("beam") {
17			continue;
18		}
19		if let Meta::List(list) = &attr.meta {
20			// Allow mixing identifiers and name-value pairs in #[beam(...)]
21			let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
22			if let Ok(metas) = parser.parse2(list.tokens.clone()) {
23				for meta in metas {
24					if let Meta::Path(path) = meta {
25						if path.is_ident(name) {
26							return true;
27						}
28					}
29				}
30			}
31		}
32	}
33	false
34}
35
36fn get_version_value(attrs: &[Attribute]) -> Option<syn::Ident> {
37	for attr in attrs {
38		if !attr.path().is_ident("beam") {
39			continue;
40		}
41		if let Meta::List(list) = &attr.meta {
42			let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
43			if let Ok(metas) = parser.parse2(list.tokens.clone()) {
44				for meta in metas {
45					if let Meta::NameValue(nv) = meta {
46						if nv.path.is_ident("min_version") {
47							if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Str(lit_str), .. }) = &nv.value {
48								return Some(syn::Ident::new(&lit_str.value(), lit_str.span()));
49							}
50						}
51					}
52				}
53			}
54		}
55	}
56	None
57}
58
59fn get_profile_value(attrs: &[Attribute]) -> Option<u8> {
60	for attr in attrs {
61		if !attr.path().is_ident("beam") {
62			continue;
63		}
64		if let Meta::List(list) = &attr.meta {
65			let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
66			if let Ok(metas) = parser.parse2(list.tokens.clone()) {
67				for meta in metas {
68					if let Meta::NameValue(nv) = meta {
69						if nv.path.is_ident("profile") {
70							if let syn::Expr::Lit(syn::ExprLit { lit: syn::Lit::Int(lit_int), .. }) = &nv.value {
71								if let Ok(profile) = lit_int.base10_parse::<u8>() {
72									return Some(profile);
73								}
74							}
75						}
76					}
77				}
78			}
79		}
80	}
81	None
82}
83
84fn get_profile_type(attrs: &[Attribute]) -> Option<syn::Type> {
85	for attr in attrs {
86		if !attr.path().is_ident("beam") {
87			continue;
88		}
89		if let Meta::List(list) = &attr.meta {
90			let parser = Punctuated::<Meta, Token![,]>::parse_terminated;
91			if let Ok(metas) = parser.parse2(list.tokens.clone()) {
92				for meta in metas {
93					if let Meta::List(profile_list) = meta {
94						if profile_list.path.is_ident("profile") {
95							// Parse the content inside profile(...) as a type
96							if let Ok(ty) = syn::parse2::<syn::Type>(profile_list.tokens.clone()) {
97								return Some(ty);
98							}
99						}
100					}
101				}
102			}
103		}
104	}
105	None
106}
107
108fn has_attr(attrs: &[Attribute], name: &str) -> bool {
109	attrs.iter().any(|attr| attr.path().is_ident(name))
110}
111
112fn get_error_message(attrs: &[Attribute]) -> Option<String> {
113	for attr in attrs {
114		if attr.path().is_ident("error") {
115			if let Meta::List(list) = &attr.meta {
116				if let Ok(lit_str) = syn::parse2::<syn::LitStr>(list.tokens.clone()) {
117					return Some(lit_str.value());
118				}
119			}
120		}
121	}
122	None
123}
124
125/// Derive macro for implementing `Message`
126///
127/// This macro can be applied to any struct that implements the necessary
128/// serialization traits (typically `der::Sequence`).
129#[proc_macro_derive(Beamable, attributes(beam))]
130pub fn derive_beamable(input: TokenStream) -> TokenStream {
131	let input = parse_macro_input!(input as DeriveInput);
132	let name = &input.ident;
133
134	let confidential = has_flag(&input.attrs, "confidential");
135	let nonrep = has_flag(&input.attrs, "nonrepudiable");
136	let compressed = has_flag(&input.attrs, "compressed");
137	let prioritized = has_flag(&input.attrs, "prioritized");
138	let message_integrity = has_flag(&input.attrs, "message_integrity");
139	let frame_integrity = has_flag(&input.attrs, "frame_integrity");
140	let min_version = get_version_value(&input.attrs);
141	let profile_value = get_profile_value(&input.attrs);
142	let profile_type = get_profile_type(&input.attrs);
143
144	// Validate that we don't have both numeric and type-based profiles
145	if profile_value.is_some() && profile_type.is_some() {
146		return syn::Error::new_spanned(
147			&input,
148			"Cannot specify both numeric profile (= N) and type-based profile (Type) simultaneously",
149		)
150		.to_compile_error()
151		.into();
152	}
153
154	// Profile-based security requirements
155	let (profile_confidential, profile_nonrep, profile_min_version) = match profile_value {
156		Some(1) => (true, true, Some(syn::Ident::new("V1", name.span()))), // FIPS
157		Some(2) => (true, true, Some(syn::Ident::new("V1", name.span()))), // Standard
158		Some(p) if p > 2 => (false, false, None),
159		_ => (false, false, None),
160	};
161
162	// Apply profile requirements (override individual flags)
163	let final_confidential = profile_confidential || confidential;
164	let final_nonrep = profile_nonrep || nonrep;
165	let final_min_version = profile_min_version.or(min_version);
166	let final_message_integrity = message_integrity;
167	let final_frame_integrity = frame_integrity;
168
169	let mut feature_checks = Vec::new();
170
171	if final_confidential && !cfg!(feature = "aead") {
172		feature_checks.push(quote! {
173			compile_error!(concat!(
174				"Message type `", stringify!(#name), "` is marked as confidential ",
175				"but the `aead` feature is not enabled. ",
176				"Enable the feature in Cargo.toml: features = [\"aead\"]"
177			));
178		});
179	}
180
181	if final_nonrep && !cfg!(feature = "signature") {
182		feature_checks.push(quote! {
183			compile_error!(concat!(
184				"Message type `", stringify!(#name), "` is marked as non-repudiable ",
185				"but the `signature` feature is not enabled. ",
186				"Enable the feature in Cargo.toml: features = [\"signature\"]"
187			));
188		});
189	}
190
191	if compressed && !cfg!(feature = "compress") {
192		feature_checks.push(quote! {
193			compile_error!(concat!(
194				"Message type `", stringify!(#name), "` is marked as compressed ",
195				"but the `compress` feature is not enabled. ",
196				"Enable the feature in Cargo.toml: features = [\"compress\"]"
197			));
198		});
199	}
200
201	if (final_message_integrity || final_frame_integrity) && !cfg!(feature = "digest") {
202		feature_checks.push(quote! {
203			compile_error!(concat!(
204				"Message type `", stringify!(#name), "` is marked as requiring message integrity ",
205				"but the `digest` feature is not enabled. ",
206				"Enable the feature in Cargo.toml: features = [\"digest\"]"
207			));
208		});
209	}
210
211	let min_version_value = if let Some(version) = final_min_version {
212		quote! { ::tightbeam::Version::#version }
213	} else {
214		quote! { ::tightbeam::Version::V0 }
215	};
216
217	let _has_profile = profile_type.is_some();
218	let profile_type_impl = if let Some(profile_ty) = &profile_type {
219		quote! {
220			const HAS_PROFILE: bool = true;
221			type Profile = #profile_ty;
222		}
223	} else {
224		// Always define HAS_PROFILE, even when false (needed for checker trait impls)
225		quote! {
226			const HAS_PROFILE: bool = false;
227			type Profile = ::tightbeam::crypto::profiles::TightbeamProfile;
228		}
229	};
230
231	// Generate checker trait implementations for compile-time OID validation
232	// When HAS_PROFILE = true: generates impls ONLY for the matching OID type from the profile (compile-time enforcement)
233	// When HAS_PROFILE = false: generates generic impls for all OID types (no enforcement, allows any)
234	// All types using #[derive(Beamable)] get these impls - types not using derive must implement manually
235	let oid_validation_helpers = if let Some(profile_ty) = &profile_type {
236		// We know the profile type, so we can reference its associated types directly
237		// ONLY implement for the exact OID types from the profile - wrong OIDs will fail to compile
238		quote! {
239			#[cfg(feature = "digest")]
240			impl ::tightbeam::builder::private::SealedDigestOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::DigestOid> for #name
241			where
242				#name: ::tightbeam::Message,
243			{}
244
245			#[cfg(feature = "digest")]
246			impl ::tightbeam::builder::CheckDigestOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::DigestOid> for #name
247			where
248				#name: ::tightbeam::Message,
249			{
250				const RESULT: () = ();
251			}
252
253			#[cfg(feature = "aead")]
254			impl ::tightbeam::builder::private::SealedAeadOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::AeadOid> for #name
255			where
256				#name: ::tightbeam::Message,
257			{}
258
259			#[cfg(feature = "aead")]
260			impl ::tightbeam::builder::CheckAeadOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::AeadOid> for #name
261			where
262				#name: ::tightbeam::Message,
263			{
264				const RESULT: () = ();
265			}
266
267			#[cfg(feature = "signature")]
268			impl ::tightbeam::builder::private::SealedSignatureOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::SignatureAlg> for #name
269			where
270				#name: ::tightbeam::Message,
271			{}
272
273			#[cfg(feature = "signature")]
274			impl ::tightbeam::builder::CheckSignatureOid<<#profile_ty as ::tightbeam::crypto::profiles::SecurityProfile>::SignatureAlg> for #name
275			where
276				#name: ::tightbeam::Message,
277			{
278				const RESULT: () = ();
279			}
280		}
281	} else {
282		// When HAS_PROFILE = false, generate generic impls for all OID types (no enforcement)
283		// These allow FrameBuilder methods to work for types without profiles
284		quote! {
285			#[cfg(feature = "digest")]
286			impl<D: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::private::SealedDigestOid<D> for #name
287			where
288				#name: ::tightbeam::Message,
289			{}
290
291			#[cfg(feature = "digest")]
292			impl<D: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::CheckDigestOid<D> for #name
293			where
294				#name: ::tightbeam::Message,
295			{
296				const RESULT: () = ();
297			}
298
299			#[cfg(feature = "aead")]
300			impl<C: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::private::SealedAeadOid<C> for #name
301			where
302				#name: ::tightbeam::Message,
303			{}
304
305			#[cfg(feature = "aead")]
306			impl<C: ::tightbeam::der::oid::AssociatedOid> ::tightbeam::builder::CheckAeadOid<C> for #name
307			where
308				#name: ::tightbeam::Message,
309			{
310				const RESULT: () = ();
311			}
312
313			#[cfg(feature = "signature")]
314			impl<S: ::tightbeam::crypto::sign::SignatureAlgorithmIdentifier> ::tightbeam::builder::private::SealedSignatureOid<S> for #name
315			where
316				#name: ::tightbeam::Message,
317			{}
318
319			#[cfg(feature = "signature")]
320			impl<S: ::tightbeam::crypto::sign::SignatureAlgorithmIdentifier> ::tightbeam::builder::CheckSignatureOid<S> for #name
321			where
322				#name: ::tightbeam::Message,
323			{
324				const RESULT: () = ();
325			}
326		}
327	};
328
329	let expanded = quote! {
330		const _: () = {
331			#(#feature_checks)*
332		};
333
334		impl ::tightbeam::Message for #name {
335			const MUST_BE_CONFIDENTIAL: bool = #final_confidential;
336			const MUST_BE_NON_REPUDIABLE: bool = #final_nonrep;
337			const MUST_HAVE_MESSAGE_INTEGRITY: bool = #final_message_integrity;
338			const MUST_HAVE_FRAME_INTEGRITY: bool = #final_frame_integrity;
339			const MUST_BE_COMPRESSED: bool = #compressed;
340			const MUST_BE_PRIORITIZED: bool = #prioritized;
341			const MIN_VERSION: ::tightbeam::Version = #min_version_value;
342			#profile_type_impl
343		}
344
345		#oid_validation_helpers
346	};
347
348	TokenStream::from(expanded)
349}
350
351/// Derive macro for implementing flag enum traits
352///
353/// This macro automatically adds the necessary attributes and trait
354/// implementations for flag enums used with the TightBeam flag system.
355#[proc_macro_derive(Flaggable)]
356pub fn derive_flaggable(input: TokenStream) -> TokenStream {
357	let input = parse_macro_input!(input as DeriveInput);
358	let name = &input.ident;
359	let name_str = name.to_string();
360
361	let expanded = quote! {
362		impl From<#name> for u8 {
363			fn from(val: #name) -> u8 {
364				val as u8
365			}
366		}
367
368		impl PartialEq<u8> for #name {
369			fn eq(&self, other: &u8) -> bool {
370				(*self as u8) == *other
371			}
372		}
373
374		impl #name {
375			pub const TYPE_NAME: &'static str = #name_str;
376		}
377	};
378
379	TokenStream::from(expanded)
380}
381
382/// Derive macro for implementing error traits with automatic Display and From
383/// implementations
384///
385/// This macro automatically implements `Display`, `Error`, and `From`
386/// conversions for error enums, similar to the `snafu` crate.
387///
388/// # Attributes
389///
390/// - `#[error("format string")]` - Specifies the display format for the variant
391/// - `#[from]` - Automatically implements `From` for the wrapped type
392#[proc_macro_derive(Errorizable, attributes(error, from))]
393pub fn derive_errorizable(input: TokenStream) -> TokenStream {
394	let input = parse_macro_input!(input as DeriveInput);
395	let name = &input.ident;
396
397	let data_enum = match &input.data {
398		syn::Data::Enum(data) => data,
399		_ => {
400			return syn::Error::new_spanned(&input, "Errorizable can only be derived for enums")
401				.to_compile_error()
402				.into();
403		}
404	};
405
406	let mut display_arms = Vec::new();
407	let mut from_impls = Vec::new();
408
409	for variant in &data_enum.variants {
410		let variant_name = &variant.ident;
411
412		// Get the error message from #[error("...")] attribute
413		let error_msg = get_error_message(&variant.attrs);
414		let has_from = has_attr(&variant.attrs, "from");
415
416		// Build the display match arm based on variant fields
417		match &variant.fields {
418			syn::Fields::Unnamed(fields) => {
419				let field_count = fields.unnamed.len();
420				let field_bindings: Vec<_> = (0..field_count)
421					.map(|i| syn::Ident::new(&format!("f{i}"), variant_name.span()))
422					.collect();
423
424				if let Some(msg) = error_msg {
425					// Check if format string contains field accessors like {expected} or {received}
426					if msg.contains("{expected") || msg.contains("{received") {
427						// Assume single field with .expected and .received properties
428						display_arms.push(quote! {
429							#name::#variant_name(ref f0) => {
430								write!(f, #msg, expected = f0.expected, received = f0.received)
431							}
432						});
433					} else {
434						display_arms.push(quote! {
435							#name::#variant_name(#(ref #field_bindings),*) => {
436								write!(f, #msg, #(#field_bindings),*)
437							}
438						});
439					}
440				} else {
441					display_arms.push(quote! {
442						#name::#variant_name(#(ref #field_bindings),*) => {
443							write!(f, "{}", stringify!(#variant_name))
444						}
445					});
446				}
447
448				// Generate From impl if #[from] is present and there's exactly one field
449				if has_from && field_count == 1 {
450					let field_type = &fields.unnamed.first().unwrap().ty;
451					from_impls.push(quote! {
452						impl From<#field_type> for #name {
453							fn from(err: #field_type) -> Self {
454								#name::#variant_name(err)
455							}
456						}
457
458						impl From<#name> for #field_type {
459							fn from(err: #name) -> Self {
460								match err {
461									#name::#variant_name(inner) => inner,
462									_ => panic!("Cannot convert {} to {}", stringify!(#name), stringify!(#field_type)),
463								}
464							}
465						}
466					});
467				}
468			}
469			syn::Fields::Named(fields) => {
470				let field_names: Vec<_> = fields.named.iter().map(|f| &f.ident).collect();
471
472				if let Some(msg) = error_msg {
473					display_arms.push(quote! {
474						#name::#variant_name { #(ref #field_names),* } => {
475							write!(f, #msg, #(#field_names = #field_names),*)
476						}
477					});
478				} else {
479					display_arms.push(quote! {
480						#name::#variant_name { .. } => {
481							write!(f, "{}", stringify!(#variant_name))
482						}
483					});
484				}
485			}
486			syn::Fields::Unit => {
487				if let Some(msg) = error_msg {
488					display_arms.push(quote! {
489						#name::#variant_name => write!(f, #msg)
490					});
491				} else {
492					display_arms.push(quote! {
493						#name::#variant_name => write!(f, "{}", stringify!(#variant_name))
494					});
495				}
496			}
497		}
498	}
499
500	let expanded = quote! {
501		impl core::fmt::Display for #name {
502			fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
503				match self {
504					#(#display_arms,)*
505				}
506			}
507		}
508
509		impl core::error::Error for #name {}
510
511		#(#from_impls)*
512	};
513
514	TokenStream::from(expanded)
515}
516
517/// Generate all configured builder macros
518#[proc_macro]
519pub fn generate_builders(_input: TokenStream) -> TokenStream {
520	let macros: Vec<_> = build::BUILDER_CONFIGS.iter().map(build::generate_builder_macro).collect();
521
522	let output = quote! {
523		#(#macros)*
524	};
525
526	output.into()
527}