Skip to main content

sourcery_macros/
lib.rs

1// These lints are triggered by darling's generated code for
2// `#[darling(default)]`.
3#![allow(clippy::option_if_let_else)]
4#![allow(clippy::needless_continue)]
5
6use darling::{FromDeriveInput, FromMeta, util::PathList};
7use heck::{ToKebabCase, ToUpperCamelCase};
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{ToTokens, quote};
11use syn::{DeriveInput, Ident, Path, parse_macro_input, parse_quote};
12
13#[allow(clippy::doc_markdown, reason = "false positive")]
14/// Build a PascalCase enum variant name from a type path.
15fn path_to_pascal_ident(path: &Path) -> Ident {
16    let mut combined = String::new();
17    for (index, segment) in path.segments.iter().enumerate() {
18        if index > 0 {
19            combined.push('_');
20        }
21        combined.push_str(&segment.ident.to_string());
22    }
23    let pascal = combined.to_upper_camel_case();
24    let span = path
25        .segments
26        .last()
27        .map_or_else(proc_macro2::Span::call_site, |segment| segment.ident.span());
28    Ident::new(&pascal, span)
29}
30
31/// Parse `key = Type` meta items into a `syn::Type`.
32fn parse_name_value_type(item: &syn::Meta) -> darling::Result<syn::Type> {
33    let error = || darling::Error::unsupported_shape("expected `key = Type`");
34    let syn::Meta::NameValue(nv) = item else {
35        return Err(error());
36    };
37    syn::parse2(nv.value.to_token_stream()).map_err(|_| error())
38}
39
40/// Returns the kind override or the default kebab-case name from the ident.
41fn default_kind(ident: &Ident, kind: Option<String>) -> String {
42    kind.unwrap_or_else(|| ident.to_string().to_kebab_case())
43}
44
45/// Wrapper for `syn::Path` that parses from `key = Type` syntax.
46#[derive(Debug, Clone)]
47struct TypePath(Path);
48
49impl FromMeta for TypePath {
50    fn from_meta(item: &syn::Meta) -> darling::Result<Self> {
51        let ty = parse_name_value_type(item)?;
52        match ty {
53            syn::Type::Path(type_path) if type_path.qself.is_none() => Ok(Self(type_path.path)),
54            _ => Err(darling::Error::unsupported_shape("expected `key = Type`")),
55        }
56    }
57}
58
59/// Wrapper for `syn::Type` that parses from `key = Type` syntax.
60#[derive(Debug, Clone)]
61struct TypeExpr(syn::Type);
62
63impl FromMeta for TypeExpr {
64    fn from_meta(item: &syn::Meta) -> darling::Result<Self> {
65        parse_name_value_type(item).map(Self)
66    }
67}
68
/// Configuration for the `#[aggregate(...)]` attribute.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(aggregate), supports(struct_any))]
struct AggregateArgs {
    /// Identifier of the deriving struct (filled in by darling).
    ident: Ident,
    /// Visibility of the deriving struct; reused for the generated event enum.
    vis: syn::Visibility,
    /// `id = Type`: becomes `Aggregate::Id`.
    id: TypePath,
    /// `error = Type`: becomes `Aggregate::Error`.
    error: TypePath,
    /// `events(...)`: event types wrapped by the generated enum. The
    /// generator rejects an empty list.
    events: PathList,
    /// `kind = "..."`: overrides `Aggregate::KIND` (default: kebab-case
    /// struct name).
    #[darling(default)]
    kind: Option<String>,
    /// `event_enum = "..."`: overrides the generated enum name (default:
    /// `{Struct}Event`).
    #[darling(default)]
    event_enum: Option<String>,
    /// `derives(...)`: extra derives for the event enum, in addition to the
    /// always-present `Clone`.
    #[darling(default)]
    derives: Option<PathList>,
}
85
/// Configuration for the `#[projection(...)]` attribute.
///
/// Supplying a non-empty `events(...)` or any of `id` / `instance_id` /
/// `metadata` makes the derive also emit a `ProjectionFilters` impl.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(projection), supports(struct_any))]
struct ProjectionArgs {
    /// Identifier of the deriving struct (filled in by darling).
    ident: Ident,
    /// `kind = "..."`: overrides `Projection::KIND` (default: kebab-case
    /// struct name).
    #[darling(default)]
    kind: Option<String>,
    /// `id = Type`: `ProjectionFilters::Id` (default `String` when filters
    /// are generated).
    #[darling(default)]
    id: Option<TypeExpr>,
    /// `instance_id = Type`: `ProjectionFilters::InstanceId` (default `()`).
    #[darling(default)]
    instance_id: Option<TypeExpr>,
    /// `metadata = Type`: `ProjectionFilters::Metadata` (default `()`).
    #[darling(default)]
    metadata: Option<TypeExpr>,
    /// `events(...)`: event types the generated filter set subscribes to;
    /// empty when the attribute is omitted.
    #[darling(default)]
    events: PathList,
}
102
103/// Captures the event type path and its generated enum variant identifier.
104struct EventSpec<'a> {
105    path: &'a Path,
106    variant: Ident,
107}
108
109impl<'a> EventSpec<'a> {
110    /// Build an event spec from a type path.
111    fn new(path: &'a Path) -> Self {
112        Self {
113            path,
114            variant: path_to_pascal_ident(path),
115        }
116    }
117}
118
119/// Parse derive input with darling and render errors as tokens.
120fn parse_or_error<T, F>(input: &DeriveInput, f: F) -> TokenStream2
121where
122    T: FromDeriveInput,
123    F: FnOnce(T) -> TokenStream2,
124{
125    match T::from_derive_input(input) {
126        Ok(args) => f(args),
127        Err(err) => err.write_errors(),
128    }
129}
130
131/// Derives the `Aggregate` trait for a struct.
132///
133/// This macro generates:
134/// - An event enum containing all aggregate event types
135/// - `EventKind` trait implementation for runtime kind dispatch
136/// - `ProjectionEvent` trait implementation for event deserialisation
137/// - `From<E>` implementations for each event type
138/// - `Aggregate` trait implementation that dispatches to `Apply<E>` for events
139///
140/// **Note:** Commands are handled via individual `Handle<C>` trait
141/// implementations. No command enum is generated - use
142/// `execute_command::<Aggregate, Command>()` directly.
143///
144/// # Attributes
145///
146/// ## Required
147/// - `id = Type` - Aggregate ID type
148/// - `error = Type` - Error type for command handling
149/// - `events(Type1, Type2, ...)` - Event types
150///
151/// ## Optional
152/// - `kind = "name"` - Aggregate type identifier (default: lowercase struct
153///   name)
154/// - `event_enum = "Name"` - Override generated event enum name (default:
155///   `{Struct}Event`)
156/// - `derives(Trait1, Trait2, ...)` - Additional derives for the generated
157///   event enum. Always includes `Clone` and `serde::Serialize`. Common
158///   additions: `Debug`, `PartialEq`, `Eq`
159///
160/// # Example
161///
162/// ```ignore
163/// #[derive(Aggregate)]
164/// #[aggregate(
165///     id = String,
166///     error = String,
167///     events(FundsDeposited, FundsWithdrawn),
168///     derives(Debug, PartialEq, Eq)
169/// )]
170/// pub struct Account {
171///     balance: i64,
172/// }
173/// ```
174#[proc_macro_derive(Aggregate, attributes(aggregate))]
175pub fn derive_aggregate(input: TokenStream) -> TokenStream {
176    let input = parse_macro_input!(input as DeriveInput);
177
178    derive_aggregate_impl(&input).into()
179}
180
181/// Internal entry point that returns tokens for the aggregate derive.
182fn derive_aggregate_impl(input: &DeriveInput) -> TokenStream2 {
183    parse_or_error::<AggregateArgs, _>(input, |args| generate_aggregate_impl(args, input))
184}
185
186/// Generate the aggregate derive implementation tokens.
187fn generate_aggregate_impl(args: AggregateArgs, input: &DeriveInput) -> TokenStream2 {
188    let event_specs: Vec<EventSpec<'_>> = args.events.iter().map(EventSpec::new).collect();
189
190    if event_specs.is_empty() {
191        return darling::Error::custom("events(...) must contain at least one event type")
192            .with_span(&input.ident)
193            .write_errors();
194    }
195
196    let struct_name = &args.ident;
197    let struct_vis = &args.vis;
198    let id_type = &args.id.0;
199    let error_type = &args.error.0;
200    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
201
202    let kind = default_kind(struct_name, args.kind);
203
204    let event_enum_name = args.event_enum.map_or_else(
205        || Ident::new(&format!("{struct_name}Event"), struct_name.span()),
206        |name| Ident::new(&name, struct_name.span()),
207    );
208
209    let event_types: Vec<&Path> = event_specs.iter().map(|spec| spec.path).collect();
210    let variant_names: Vec<&Ident> = event_specs.iter().map(|spec| &spec.variant).collect();
211
212    // Build derives list - always include Clone, add user-specified traits
213    let derives = if let Some(user_derives) = &args.derives {
214        let user_paths: Vec<&Path> = user_derives.iter().collect();
215        quote! { #[derive(Clone, #(#user_paths),*)] }
216    } else {
217        quote! { #[derive(Clone)] }
218    };
219
220    let expanded = quote! {
221        #derives
222        #struct_vis enum #event_enum_name {
223            #(#variant_names(#event_types)),*
224        }
225
226        impl ::sourcery::event::EventKind for #event_enum_name {
227            fn kind(&self) -> &'static str {
228                match self {
229                    #(Self::#variant_names(_) => #event_types::KIND),*
230                }
231            }
232        }
233
234        impl ::serde::Serialize for #event_enum_name {
235            fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
236            where
237                S: ::serde::Serializer,
238            {
239                match self {
240                    #(Self::#variant_names(inner) => ::serde::Serialize::serialize(inner, serializer)),*
241                }
242            }
243        }
244
245        impl ::sourcery::ProjectionEvent for #event_enum_name {
246            const EVENT_KINDS: &'static [&'static str] = &[#(#event_types::KIND),*];
247
248            fn from_stored<S: ::sourcery::store::EventStore>(
249                stored: &::sourcery::store::StoredEvent<S::Id, S::Position, S::Data, S::Metadata>,
250                store: &S,
251            ) -> Result<Self, ::sourcery::event::EventDecodeError<S::Error>> {
252                match stored.kind() {
253                    #(#event_types::KIND => Ok(Self::#variant_names(
254                        store.decode_event(stored).map_err(::sourcery::event::EventDecodeError::Store)?
255                    )),)*
256                    _ => Err(::sourcery::event::EventDecodeError::UnknownKind {
257                        kind: stored.kind().to_string(),
258                        expected: Self::EVENT_KINDS,
259                    }),
260                }
261            }
262        }
263
264        #(
265            impl From<#event_types> for #event_enum_name {
266                fn from(event: #event_types) -> Self {
267                    Self::#variant_names(event)
268                }
269            }
270        )*
271
272        impl #impl_generics ::sourcery::Aggregate for #struct_name #ty_generics #where_clause {
273            const KIND: &'static str = #kind;
274            type Event = #event_enum_name;
275            type Error = #error_type;
276            type Id = #id_type;
277
278            fn apply(&mut self, event: &Self::Event) {
279                match event {
280                    #(#event_enum_name::#variant_names(e) => ::sourcery::Apply::apply(self, e)),*
281                }
282            }
283        }
284    };
285
286    expanded
287}
288
289/// Derives the `Projection` trait for a struct.
290///
291/// This macro always generates:
292/// - `Projection` trait implementation with `KIND` constant
293///
294/// It can also generate [`ProjectionFilters`] for the common case via
295/// `events(...)` (or explicit `id` / `instance_id` / `metadata`
296/// attributes), using `Default` for initialisation and a simple
297/// `Filters::new().event::<E>()...` filter set.
298///
299/// # Attributes
300///
301/// ## Optional
302/// - `kind = "name"` - Projection type identifier (default: kebab-case struct
303///   name)
304/// - `events(Event1, Event2, ...)` - Auto-generate `ProjectionFilters::filters`
305///   with global event subscriptions.
306/// - `id = Type` - Override `ProjectionFilters::Id` (default: `String` when
307///   auto-generating `ProjectionFilters`)
308/// - `instance_id = Type` - Override `ProjectionFilters::InstanceId` (default:
309///   `()` when auto-generating `ProjectionFilters`)
310/// - `metadata = Type` - Override `ProjectionFilters::Metadata` (default: `()`
311///   when auto-generating `ProjectionFilters`)
312///
313/// # Example
314///
315/// ```ignore
316/// #[derive(Default, Projection)]
317/// pub struct AccountLedger {
318///     total: i64,
319/// }
320/// ```
321#[proc_macro_derive(Projection, attributes(projection))]
322pub fn derive_projection(input: TokenStream) -> TokenStream {
323    let input = parse_macro_input!(input as DeriveInput);
324
325    derive_projection_impl(&input).into()
326}
327
328/// Internal entry point that returns tokens for the projection derive.
329fn derive_projection_impl(input: &DeriveInput) -> TokenStream2 {
330    parse_or_error::<ProjectionArgs, _>(input, |args| generate_projection_impl(args, input))
331}
332
333/// Generate the projection derive implementation tokens.
334fn generate_projection_impl(args: ProjectionArgs, input: &DeriveInput) -> TokenStream2 {
335    let struct_name = &args.ident;
336    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
337    let kind = default_kind(struct_name, args.kind);
338
339    let projection_impl = quote! {
340        impl #impl_generics ::sourcery::Projection for #struct_name #ty_generics #where_clause {
341            const KIND: &'static str = #kind;
342        }
343    };
344
345    let auto_projection_filters = !args.events.is_empty()
346        || args.id.is_some()
347        || args.instance_id.is_some()
348        || args.metadata.is_some();
349
350    if !auto_projection_filters {
351        return projection_impl;
352    }
353
354    let id_ty = args.id.map_or_else(|| parse_quote!(String), |ty| ty.0);
355    let instance_id_ty = args.instance_id.map_or_else(|| parse_quote!(()), |ty| ty.0);
356    let metadata_ty = args.metadata.map_or_else(|| parse_quote!(()), |ty| ty.0);
357    let subscribed_events: Vec<&Path> = args.events.iter().collect();
358
359    let filters_body = if subscribed_events.is_empty() {
360        quote! { ::sourcery::Filters::new() }
361    } else {
362        quote! { ::sourcery::Filters::new() #(.event::<#subscribed_events>())* }
363    };
364
365    let projection_filters_where = if let Some(where_clause) = where_clause {
366        let predicates = &where_clause.predicates;
367        quote! {
368            where
369                #predicates,
370                #struct_name #ty_generics: ::core::default::Default
371        }
372    } else {
373        quote! {
374            where
375                #struct_name #ty_generics: ::core::default::Default
376        }
377    };
378
379    quote! {
380        #projection_impl
381
382        impl #impl_generics ::sourcery::ProjectionFilters for #struct_name #ty_generics
383        #projection_filters_where
384        {
385            type Id = #id_ty;
386            type InstanceId = #instance_id_ty;
387            type Metadata = #metadata_ty;
388
389            fn init(_instance_id: &Self::InstanceId) -> Self {
390                Self::default()
391            }
392
393            fn filters<S>(_instance_id: &Self::InstanceId) -> ::sourcery::Filters<S, Self>
394            where
395                S: ::sourcery::store::EventStore<Id = Self::Id, Metadata = Self::Metadata>,
396            {
397                #filters_body
398            }
399        }
400    }
401}
402
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::*;

    /// Stringify `tokens` and drop all whitespace, so assertions do not depend
    /// on how `TokenStream` spaces its output.
    fn compact(tokens: &TokenStream2) -> String {
        tokens.to_string().split_whitespace().collect()
    }

    /// Run the aggregate derive on `input` and return its compacted output.
    fn expand_aggregate(input: &DeriveInput) -> String {
        compact(&derive_aggregate_impl(input))
    }

    /// Run the projection derive on `input` and return its compacted output.
    fn expand_projection(input: &DeriveInput) -> String {
        compact(&derive_projection_impl(input))
    }

    /// Assert that every needle occurs in `haystack`, naming the first one
    /// that is missing.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(needle),
                "expected `{needle}` in expansion: {haystack}"
            );
        }
    }

    #[test]
    /// Verifies path parsing for `id = Type` syntax.
    fn type_path_parses_name_value_path() {
        let meta: syn::Meta = parse_quote!(id = String);
        let TypePath(parsed) = TypePath::from_meta(&meta).unwrap();
        let expected: Path = parse_quote!(String);
        assert_eq!(parsed, expected);
    }

    #[test]
    /// Ensures non-path values are rejected for `TypePath`.
    fn type_path_rejects_non_path_value() {
        let meta: syn::Meta = parse_quote!(id = "String");
        let message = TypePath::from_meta(&meta).unwrap_err().to_string();
        assert!(message.contains("expected `key = Type`"));
    }

    #[test]
    /// Confirms default kind and event enum names are generated.
    fn generate_aggregate_impl_uses_default_kind_and_event_enum() {
        let input: DeriveInput = parse_quote! {
            #[aggregate(id = String, error = String, events(FundsDeposited))]
            pub struct Account {
                balance: i64,
            }
        };

        assert_contains_all(
            &expand_aggregate(&input),
            &[
                "enumAccountEvent",
                "impl::sourcery::AggregateforAccount",
                "constKIND:&'staticstr=\"account\"",
            ],
        );
    }

    #[test]
    /// Confirms explicit kind and enum overrides are honored.
    fn generate_aggregate_impl_respects_kind_and_event_enum_overrides() {
        let input: DeriveInput = parse_quote! {
            #[aggregate(
                id = String,
                error = String,
                events(FundsDeposited),
                kind = "bank-account",
                event_enum = "BankAccountEvent"
            )]
            pub struct Account {
                balance: i64,
            }
        };

        assert_contains_all(
            &expand_aggregate(&input),
            &[
                "enumBankAccountEvent",
                "constKIND:&'staticstr=\"bank-account\"",
            ],
        );
    }

    #[test]
    /// Ensures empty event lists yield a compile-time error.
    fn generate_aggregate_impl_emits_error_on_empty_events_list() {
        let input: DeriveInput = parse_quote! {
            #[aggregate(id = String, error = String, events())]
            pub struct Account;
        };

        assert_contains_all(
            &expand_aggregate(&input),
            &["events(...)mustcontainatleastoneeventtype"],
        );
    }

    #[test]
    /// Confirms default kind for projections (no attributes required).
    fn generate_projection_impl_uses_default_kind() {
        let input: DeriveInput = parse_quote! {
            pub struct AccountLedger;
        };

        assert_contains_all(
            &expand_projection(&input),
            &[
                "impl::sourcery::ProjectionforAccountLedger",
                "constKIND:&'staticstr=\"account-ledger\"",
            ],
        );
    }

    #[test]
    /// Confirms projection kind override is honored.
    fn generate_projection_impl_respects_kind_override() {
        let input: DeriveInput = parse_quote! {
            #[projection(kind = "custom-ledger")]
            pub struct AccountLedger;
        };

        assert_contains_all(
            &expand_projection(&input),
            &["constKIND:&'staticstr=\"custom-ledger\""],
        );
    }

    #[test]
    /// Confirms `events(...)` generates a `ProjectionFilters` impl with
    /// defaults.
    fn generate_projection_impl_with_events_generates_projection_filters() {
        let input: DeriveInput = parse_quote! {
            #[projection(events(FundsDeposited, FundsWithdrawn))]
            pub struct AccountLedger;
        };

        assert_contains_all(
            &expand_projection(&input),
            &[
                "impl::sourcery::ProjectionFiltersforAccountLedger",
                "typeId=String",
                "typeInstanceId=()",
                "typeMetadata=()",
                "event::<FundsDeposited>()",
                "event::<FundsWithdrawn>()",
            ],
        );
    }

    #[test]
    /// Confirms projection filter type overrides are honored.
    fn generate_projection_impl_respects_projection_filter_type_overrides() {
        let input: DeriveInput = parse_quote! {
            #[projection(
                id = uuid::Uuid,
                instance_id = String,
                metadata = EventMetadata,
                events(FundsDeposited)
            )]
            pub struct AccountLedger;
        };

        assert_contains_all(
            &expand_projection(&input),
            &[
                "typeId=uuid::Uuid",
                "typeInstanceId=String",
                "typeMetadata=EventMetadata",
            ],
        );
    }
}