1#![allow(clippy::option_if_let_else)]
4#![allow(clippy::needless_continue)]
5
6use darling::{FromDeriveInput, FromMeta, util::PathList};
7use heck::{ToKebabCase, ToUpperCamelCase};
8use proc_macro::TokenStream;
9use proc_macro2::TokenStream as TokenStream2;
10use quote::{ToTokens, quote};
11use syn::{DeriveInput, Ident, Path, parse_macro_input, parse_quote};
12
13#[allow(clippy::doc_markdown, reason = "false positive")]
14fn path_to_pascal_ident(path: &Path) -> Ident {
16 let mut combined = String::new();
17 for (index, segment) in path.segments.iter().enumerate() {
18 if index > 0 {
19 combined.push('_');
20 }
21 combined.push_str(&segment.ident.to_string());
22 }
23 let pascal = combined.to_upper_camel_case();
24 let span = path
25 .segments
26 .last()
27 .map_or_else(proc_macro2::Span::call_site, |segment| segment.ident.span());
28 Ident::new(&pascal, span)
29}
30
31fn parse_name_value_type(item: &syn::Meta) -> darling::Result<syn::Type> {
33 let error = || darling::Error::unsupported_shape("expected `key = Type`");
34 let syn::Meta::NameValue(nv) = item else {
35 return Err(error());
36 };
37 syn::parse2(nv.value.to_token_stream()).map_err(|_| error())
38}
39
40fn default_kind(ident: &Ident, kind: Option<String>) -> String {
42 kind.unwrap_or_else(|| ident.to_string().to_kebab_case())
43}
44
45#[derive(Debug, Clone)]
47struct TypePath(Path);
48
49impl FromMeta for TypePath {
50 fn from_meta(item: &syn::Meta) -> darling::Result<Self> {
51 let ty = parse_name_value_type(item)?;
52 match ty {
53 syn::Type::Path(type_path) if type_path.qself.is_none() => Ok(Self(type_path.path)),
54 _ => Err(darling::Error::unsupported_shape("expected `key = Type`")),
55 }
56 }
57}
58
59#[derive(Debug, Clone)]
61struct TypeExpr(syn::Type);
62
63impl FromMeta for TypeExpr {
64 fn from_meta(item: &syn::Meta) -> darling::Result<Self> {
65 parse_name_value_type(item).map(Self)
66 }
67}
68
/// Arguments of `#[derive(Aggregate)]`, read from `#[aggregate(...)]`.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(aggregate), supports(struct_any))]
struct AggregateArgs {
    /// Identifier of the struct the derive is applied to.
    ident: Ident,
    /// Visibility of the struct; reused for the generated event enum.
    vis: syn::Visibility,
    /// `id = Type`: emitted as the aggregate's `Id` associated type.
    id: TypePath,
    /// `error = Type`: emitted as the aggregate's `Error` associated type.
    error: TypePath,
    /// `events(A, B, ...)`: event types; each becomes one enum variant.
    events: PathList,
    /// Optional `kind = "..."`; defaults to the kebab-case struct name.
    #[darling(default)]
    kind: Option<String>,
    /// Optional `event_enum = "..."`; defaults to `<StructName>Event`.
    #[darling(default)]
    event_enum: Option<String>,
    /// Optional `derives(...)`: extra derives placed after `Clone` on the
    /// generated event enum.
    #[darling(default)]
    derives: Option<PathList>,
}
85
/// Arguments of `#[derive(Projection)]`, read from `#[projection(...)]`.
///
/// Supplying any of `id`, `instance_id`, `metadata`, or a non-empty `events`
/// list additionally generates a `ProjectionFilters` impl.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(projection), supports(struct_any))]
struct ProjectionArgs {
    /// Identifier of the struct the derive is applied to.
    ident: Ident,
    /// Optional `kind = "..."`; defaults to the kebab-case struct name.
    #[darling(default)]
    kind: Option<String>,
    /// Optional `id = Type` for `ProjectionFilters::Id`; defaults to `String`.
    #[darling(default)]
    id: Option<TypeExpr>,
    /// Optional `instance_id = Type` for `ProjectionFilters::InstanceId`;
    /// defaults to `()`.
    #[darling(default)]
    instance_id: Option<TypeExpr>,
    /// Optional `metadata = Type` for `ProjectionFilters::Metadata`;
    /// defaults to `()`.
    #[darling(default)]
    metadata: Option<TypeExpr>,
    /// Optional `events(A, B, ...)`: event types registered on the filters.
    #[darling(default)]
    events: PathList,
}
102
103struct EventSpec<'a> {
105 path: &'a Path,
106 variant: Ident,
107}
108
109impl<'a> EventSpec<'a> {
110 fn new(path: &'a Path) -> Self {
112 Self {
113 path,
114 variant: path_to_pascal_ident(path),
115 }
116 }
117}
118
119fn parse_or_error<T, F>(input: &DeriveInput, f: F) -> TokenStream2
121where
122 T: FromDeriveInput,
123 F: FnOnce(T) -> TokenStream2,
124{
125 match T::from_derive_input(input) {
126 Ok(args) => f(args),
127 Err(err) => err.write_errors(),
128 }
129}
130
131#[proc_macro_derive(Aggregate, attributes(aggregate))]
175pub fn derive_aggregate(input: TokenStream) -> TokenStream {
176 let input = parse_macro_input!(input as DeriveInput);
177
178 derive_aggregate_impl(&input).into()
179}
180
181fn derive_aggregate_impl(input: &DeriveInput) -> TokenStream2 {
183 parse_or_error::<AggregateArgs, _>(input, |args| generate_aggregate_impl(args, input))
184}
185
186fn generate_aggregate_impl(args: AggregateArgs, input: &DeriveInput) -> TokenStream2 {
188 let event_specs: Vec<EventSpec<'_>> = args.events.iter().map(EventSpec::new).collect();
189
190 if event_specs.is_empty() {
191 return darling::Error::custom("events(...) must contain at least one event type")
192 .with_span(&input.ident)
193 .write_errors();
194 }
195
196 let struct_name = &args.ident;
197 let struct_vis = &args.vis;
198 let id_type = &args.id.0;
199 let error_type = &args.error.0;
200 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
201
202 let kind = default_kind(struct_name, args.kind);
203
204 let event_enum_name = args.event_enum.map_or_else(
205 || Ident::new(&format!("{struct_name}Event"), struct_name.span()),
206 |name| Ident::new(&name, struct_name.span()),
207 );
208
209 let event_types: Vec<&Path> = event_specs.iter().map(|spec| spec.path).collect();
210 let variant_names: Vec<&Ident> = event_specs.iter().map(|spec| &spec.variant).collect();
211
212 let derives = if let Some(user_derives) = &args.derives {
214 let user_paths: Vec<&Path> = user_derives.iter().collect();
215 quote! { #[derive(Clone, #(#user_paths),*)] }
216 } else {
217 quote! { #[derive(Clone)] }
218 };
219
220 let expanded = quote! {
221 #derives
222 #struct_vis enum #event_enum_name {
223 #(#variant_names(#event_types)),*
224 }
225
226 impl ::sourcery::event::EventKind for #event_enum_name {
227 fn kind(&self) -> &'static str {
228 match self {
229 #(Self::#variant_names(_) => #event_types::KIND),*
230 }
231 }
232 }
233
234 impl ::serde::Serialize for #event_enum_name {
235 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
236 where
237 S: ::serde::Serializer,
238 {
239 match self {
240 #(Self::#variant_names(inner) => ::serde::Serialize::serialize(inner, serializer)),*
241 }
242 }
243 }
244
245 impl ::sourcery::ProjectionEvent for #event_enum_name {
246 const EVENT_KINDS: &'static [&'static str] = &[#(#event_types::KIND),*];
247
248 fn from_stored<S: ::sourcery::store::EventStore>(
249 stored: &::sourcery::store::StoredEvent<S::Id, S::Position, S::Data, S::Metadata>,
250 store: &S,
251 ) -> Result<Self, ::sourcery::event::EventDecodeError<S::Error>> {
252 match stored.kind() {
253 #(#event_types::KIND => Ok(Self::#variant_names(
254 store.decode_event(stored).map_err(::sourcery::event::EventDecodeError::Store)?
255 )),)*
256 _ => Err(::sourcery::event::EventDecodeError::UnknownKind {
257 kind: stored.kind().to_string(),
258 expected: Self::EVENT_KINDS,
259 }),
260 }
261 }
262 }
263
264 #(
265 impl From<#event_types> for #event_enum_name {
266 fn from(event: #event_types) -> Self {
267 Self::#variant_names(event)
268 }
269 }
270 )*
271
272 impl #impl_generics ::sourcery::Aggregate for #struct_name #ty_generics #where_clause {
273 const KIND: &'static str = #kind;
274 type Event = #event_enum_name;
275 type Error = #error_type;
276 type Id = #id_type;
277
278 fn apply(&mut self, event: &Self::Event) {
279 match event {
280 #(#event_enum_name::#variant_names(e) => ::sourcery::Apply::apply(self, e)),*
281 }
282 }
283 }
284 };
285
286 expanded
287}
288
289#[proc_macro_derive(Projection, attributes(projection))]
322pub fn derive_projection(input: TokenStream) -> TokenStream {
323 let input = parse_macro_input!(input as DeriveInput);
324
325 derive_projection_impl(&input).into()
326}
327
328fn derive_projection_impl(input: &DeriveInput) -> TokenStream2 {
330 parse_or_error::<ProjectionArgs, _>(input, |args| generate_projection_impl(args, input))
331}
332
333fn generate_projection_impl(args: ProjectionArgs, input: &DeriveInput) -> TokenStream2 {
335 let struct_name = &args.ident;
336 let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
337 let kind = default_kind(struct_name, args.kind);
338
339 let projection_impl = quote! {
340 impl #impl_generics ::sourcery::Projection for #struct_name #ty_generics #where_clause {
341 const KIND: &'static str = #kind;
342 }
343 };
344
345 let auto_projection_filters = !args.events.is_empty()
346 || args.id.is_some()
347 || args.instance_id.is_some()
348 || args.metadata.is_some();
349
350 if !auto_projection_filters {
351 return projection_impl;
352 }
353
354 let id_ty = args.id.map_or_else(|| parse_quote!(String), |ty| ty.0);
355 let instance_id_ty = args.instance_id.map_or_else(|| parse_quote!(()), |ty| ty.0);
356 let metadata_ty = args.metadata.map_or_else(|| parse_quote!(()), |ty| ty.0);
357 let subscribed_events: Vec<&Path> = args.events.iter().collect();
358
359 let filters_body = if subscribed_events.is_empty() {
360 quote! { ::sourcery::Filters::new() }
361 } else {
362 quote! { ::sourcery::Filters::new() #(.event::<#subscribed_events>())* }
363 };
364
365 let projection_filters_where = if let Some(where_clause) = where_clause {
366 let predicates = &where_clause.predicates;
367 quote! {
368 where
369 #predicates,
370 #struct_name #ty_generics: ::core::default::Default
371 }
372 } else {
373 quote! {
374 where
375 #struct_name #ty_generics: ::core::default::Default
376 }
377 };
378
379 quote! {
380 #projection_impl
381
382 impl #impl_generics ::sourcery::ProjectionFilters for #struct_name #ty_generics
383 #projection_filters_where
384 {
385 type Id = #id_ty;
386 type InstanceId = #instance_id_ty;
387 type Metadata = #metadata_ty;
388
389 fn init(_instance_id: &Self::InstanceId) -> Self {
390 Self::default()
391 }
392
393 fn filters<S>(_instance_id: &Self::InstanceId) -> ::sourcery::Filters<S, Self>
394 where
395 S: ::sourcery::store::EventStore<Id = Self::Id, Metadata = Self::Metadata>,
396 {
397 #filters_body
398 }
399 }
400 }
401}
402
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::*;

    /// Renders `tokens` as text with every whitespace character removed, so
    /// assertions do not depend on token spacing.
    fn compact(tokens: &TokenStream2) -> String {
        let mut rendered = tokens.to_string();
        rendered.retain(|c| !c.is_whitespace());
        rendered
    }

    /// Expands the aggregate derive for `input` and returns the compacted output.
    fn expand_aggregate(input: &DeriveInput) -> String {
        compact(&derive_aggregate_impl(input))
    }

    /// Expands the projection derive for `input` and returns the compacted output.
    fn expand_projection(input: &DeriveInput) -> String {
        compact(&derive_projection_impl(input))
    }

    #[test]
    fn type_path_parses_name_value_path() {
        let meta: syn::Meta = parse_quote!(id = String);

        let TypePath(path) = TypePath::from_meta(&meta).unwrap();

        assert_eq!(path, parse_quote!(String));
    }

    #[test]
    fn type_path_rejects_non_path_value() {
        let meta: syn::Meta = parse_quote!(id = "String");

        let message = TypePath::from_meta(&meta).unwrap_err().to_string();

        assert!(message.contains("expected `key = Type`"));
    }

    #[test]
    fn generate_aggregate_impl_uses_default_kind_and_event_enum() {
        let output = expand_aggregate(&parse_quote! {
            #[aggregate(id = String, error = String, events(FundsDeposited))]
            pub struct Account {
                balance: i64,
            }
        });

        assert!(output.contains("enumAccountEvent"));
        assert!(output.contains("impl::sourcery::AggregateforAccount"));
        assert!(output.contains("constKIND:&'staticstr=\"account\""));
    }

    #[test]
    fn generate_aggregate_impl_respects_kind_and_event_enum_overrides() {
        let output = expand_aggregate(&parse_quote! {
            #[aggregate(
                id = String,
                error = String,
                events(FundsDeposited),
                kind = "bank-account",
                event_enum = "BankAccountEvent"
            )]
            pub struct Account {
                balance: i64,
            }
        });

        assert!(output.contains("enumBankAccountEvent"));
        assert!(output.contains("constKIND:&'staticstr=\"bank-account\""));
    }

    #[test]
    fn generate_aggregate_impl_emits_error_on_empty_events_list() {
        let output = expand_aggregate(&parse_quote! {
            #[aggregate(id = String, error = String, events())]
            pub struct Account;
        });

        assert!(output.contains("events(...)mustcontainatleastoneeventtype"));
    }

    #[test]
    fn generate_projection_impl_uses_default_kind() {
        let output = expand_projection(&parse_quote! {
            pub struct AccountLedger;
        });

        assert!(output.contains("impl::sourcery::ProjectionforAccountLedger"));
        assert!(output.contains("constKIND:&'staticstr=\"account-ledger\""));
    }

    #[test]
    fn generate_projection_impl_respects_kind_override() {
        let output = expand_projection(&parse_quote! {
            #[projection(kind = "custom-ledger")]
            pub struct AccountLedger;
        });

        assert!(output.contains("constKIND:&'staticstr=\"custom-ledger\""));
    }

    #[test]
    fn generate_projection_impl_with_events_generates_projection_filters() {
        let output = expand_projection(&parse_quote! {
            #[projection(events(FundsDeposited, FundsWithdrawn))]
            pub struct AccountLedger;
        });

        assert!(output.contains("impl::sourcery::ProjectionFiltersforAccountLedger"));
        assert!(output.contains("typeId=String"));
        assert!(output.contains("typeInstanceId=()"));
        assert!(output.contains("typeMetadata=()"));
        assert!(output.contains("event::<FundsDeposited>()"));
        assert!(output.contains("event::<FundsWithdrawn>()"));
    }

    #[test]
    fn generate_projection_impl_respects_projection_filter_type_overrides() {
        let output = expand_projection(&parse_quote! {
            #[projection(
                id = uuid::Uuid,
                instance_id = String,
                metadata = EventMetadata,
                events(FundsDeposited)
            )]
            pub struct AccountLedger;
        });

        assert!(output.contains("typeId=uuid::Uuid"));
        assert!(output.contains("typeInstanceId=String"));
        assert!(output.contains("typeMetadata=EventMetadata"));
    }
}