Skip to main content

tatara_rust_composite/
lib.rs

1//! `tatara-rust-composite` — L2 composition primitive.
2//!
3//! `CompositeDeriveSpec` bundles N inner derive Specs (any kind) under
4//! one user-facing `#[derive(<name>)]`. The emitted proc-macro
5//! crate dispatches the consumer's `DeriveInput` to each inner Spec's
6//! own emitter and concatenates the produced `TokenStream`s.
7//!
8//! Authoring shape:
9//!
10//! ```
11//! use tatara_rust_ast::{CompileToCrate, Ident};
12//! use tatara_rust_derive::{PerFieldDeriveSpec, PerFieldTarget};
13//! use tatara_rust_composite::{CompositeDeriveSpec, CompositeMember};
14//!
15//! let getter = PerFieldDeriveSpec {
16//!     trait_name: Ident::new("AccessorGetter"),
17//!     target: PerFieldTarget::NamedStruct,
18//!     trait_ref: None,
19//!     per_field_template:
20//!         "pub fn #field_name(&self) -> &#field_ty { &self.#field_name }".into(),
21//!     method_name_template: None,
22//!     impl_prelude: None,
23//!     skip_fields: vec![],
24//!     field_attribute: None,
25//! };
26//! let setter = PerFieldDeriveSpec {
27//!     trait_name: Ident::new("AccessorSetter"),
28//!     target: PerFieldTarget::NamedStruct,
29//!     trait_ref: None,
30//!     per_field_template:
31//!         "pub fn #method_ident(&mut self, v: #field_ty) { self.#field_name = v; }".into(),
32//!     method_name_template: Some("set_{}".into()),
33//!     impl_prelude: None,
34//!     skip_fields: vec![],
35//!     field_attribute: None,
36//! };
37//! let bundle = CompositeDeriveSpec {
38//!     bundle_name: Ident::new("Accessor"),
39//!     members: vec![
40//!         CompositeMember::PerField(getter),
41//!         CompositeMember::PerField(setter),
42//!     ],
43//! };
44//! let scaffold = bundle.compile_to_crate("accessor-derive").unwrap();
45//! assert!(scaffold.to_files().contains_key("src/lib.rs"));
46//! ```
47
48use serde::{Deserialize, Serialize};
49use tatara_rust_ast::{AstError, CompileToCrate, CrateScaffold, Ident, ToRustTokens};
50use tatara_rust_derive::{PerFieldDeriveSpec, PerVariantDeriveSpec, ProcDeriveSpec};
51
52#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
53pub struct CompositeDeriveSpec {
54    /// User-facing `#[derive(<bundle_name>)]`.
55    pub bundle_name: Ident,
56    /// Inner derive Specs whose emissions get concatenated.
57    pub members: Vec<CompositeMember>,
58}
59
60/// Tagged-enum dispatch over the three derive shapes we know how to
61/// emit today. Adding a 4th kind = one variant + one match arm in
62/// `member_emit_call`.
63#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
64#[serde(tag = "kind", rename_all = "kebab-case")]
65pub enum CompositeMember {
66    Simple(ProcDeriveSpec),
67    PerField(PerFieldDeriveSpec),
68    PerVariant(PerVariantDeriveSpec),
69}
70
71impl CompositeDeriveSpec {
72    fn fn_name(&self) -> String {
73        let s = &self.bundle_name.0;
74        let mut out = String::from("derive_");
75        for (i, c) in s.chars().enumerate() {
76            if c.is_uppercase() {
77                if i > 0 {
78                    out.push('_');
79                }
80                out.extend(c.to_lowercase());
81            } else {
82                out.push(c);
83            }
84        }
85        out
86    }
87}
88
89impl CompileToCrate for CompositeDeriveSpec {
90    fn compile_to_crate(&self, crate_name: &str) -> Result<CrateScaffold, AstError> {
91        let mut s = CrateScaffold::new(crate_name, "0.1.0");
92        s.add_file("Cargo.toml", render_cargo_toml(crate_name));
93        s.add_file("src/lib.rs", render_lib_rs(self)?);
94        Ok(s)
95    }
96}
97
98fn render_cargo_toml(crate_name: &str) -> String {
99    tatara_rust_ast::render_proc_macro_cargo_toml(
100        crate_name,
101        "Composite derive proc-macro — fans one #[derive(...)] out to N inner Specs.",
102    )
103}
104
105/// Compile each inner member to its own derive lib.rs, then inline the
106/// per-member dispatch into a single fn that fans the DeriveInput out
107/// to every member's emit logic. To keep the spike simple, we embed
108/// each member's lib.rs body verbatim — without the `#[proc_macro_derive]`
109/// attribute (only the outer bundle gets it) — and concatenate the
110/// resulting TokenStream pieces.
111fn render_lib_rs(spec: &CompositeDeriveSpec) -> Result<String, AstError> {
112    let bundle = &spec.bundle_name.0;
113    let fn_name = spec.fn_name();
114
115    // For each inner member, emit a closure that builds the impl shape.
116    // The closure takes `&DeriveInput`, returns `proc_macro2::TokenStream`.
117    let mut closures = String::new();
118    let mut calls = String::new();
119    for (i, m) in spec.members.iter().enumerate() {
120        let cname = format!("__member_{i}");
121        closures.push_str(&format!(
122            "    let {cname} = |input: &syn::DeriveInput| -> proc_macro2::TokenStream {{\n"
123        ));
124        closures.push_str(&render_member_body(m));
125        closures.push_str("    };\n");
126        calls.push_str(&format!("    let __out_{i} = {cname}(&input);\n"));
127    }
128    let stitched = (0..spec.members.len())
129        .map(|i| format!("#__out_{i}"))
130        .collect::<Vec<_>>()
131        .join(" ");
132
133    let mut out = String::new();
134    out.push_str("// GENERATED by tatara-rust-composite::CompositeDeriveSpec.\n");
135    out.push_str("use proc_macro::TokenStream;\n");
136    out.push_str("use quote::quote;\n");
137    out.push_str("use syn::parse_macro_input;\n\n");
138    out.push_str(&format!("#[proc_macro_derive({bundle})]\n"));
139    out.push_str(&format!(
140        "pub fn {fn_name}(input: TokenStream) -> TokenStream {{\n"
141    ));
142    out.push_str("    let input = parse_macro_input!(input as syn::DeriveInput);\n");
143    out.push_str(&closures);
144    out.push_str(&calls);
145    for i in 0..spec.members.len() {
146        let _ = i;
147    }
148    out.push_str(&format!(
149        "    let expanded = quote! {{ {stitched} }};\n"
150    ));
151    out.push_str("    TokenStream::from(expanded)\n");
152    out.push_str("}\n");
153
154    Ok(out)
155}
156
157/// Render a single inner Spec's emit logic as a closure body. Mirrors
158/// the body of each Spec's own `derive_*` fn (sans the outer
159/// `#[proc_macro_derive]` attribute + the parse_macro_input call).
160fn render_member_body(m: &CompositeMember) -> String {
161    match m {
162        CompositeMember::Simple(spec) => {
163            // Simple ProcDerive: emit `impl Trait for #name { items… }`.
164            // The impl is rendered into the closure body verbatim.
165            let body = spec.impl_template.to_rust_tokens().to_string();
166            let body = body
167                .replace(tatara_rust_derive::SENTINEL_SELF_TYPE, "#name")
168                .replace("# __SELF_NAME__", "#name")
169                .replace(tatara_rust_derive::SENTINEL_SELF_NAME, "#name");
170            format!(
171                r#"        let name = &input.ident;
172        quote! {{ {body} }}
173"#
174            )
175        }
176        CompositeMember::PerField(spec) => render_per_field_body(spec),
177        CompositeMember::PerVariant(spec) => render_per_variant_body(spec),
178    }
179}
180
181fn render_per_field_body(spec: &PerFieldDeriveSpec) -> String {
182    let impl_open = match &spec.trait_ref {
183        None => "impl #self_name".to_string(),
184        Some(t) => format!("impl {t} for #self_name"),
185    };
186    let method_ident_let = match &spec.method_name_template {
187        None => String::new(),
188        Some(tpl) => format!(
189            "            let method_ident = quote::format_ident!(\"{tpl}\", field_name.to_string());\n"
190        ),
191    };
192    let prelude = spec.impl_prelude.as_deref().unwrap_or("");
193    let tpl = &spec.per_field_template;
194
195    let mut out = String::new();
196    out.push_str("        let self_name = &input.ident;\n");
197    out.push_str("        let fields = match &input.data {\n");
198    out.push_str("            syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Named(named), .. }) => &named.named,\n");
199    out.push_str("            _ => return syn::Error::new_spanned(self_name, \"composite per-field member needs named-fields struct\").to_compile_error(),\n");
200    out.push_str("        };\n");
201    out.push_str("        let per_field = fields.iter().map(|f| {\n");
202    out.push_str("            let field_name = f.ident.as_ref().expect(\"named\");\n");
203    out.push_str("            let field_ty = &f.ty;\n");
204    out.push_str(&method_ident_let);
205    out.push_str("            quote! { ");
206    out.push_str(tpl);
207    out.push_str(" }\n");
208    out.push_str("        });\n");
209    out.push_str(&format!(
210        "        quote! {{ {impl_open} {{ {prelude} #(#per_field)* }} }}\n"
211    ));
212    out
213}
214
215fn render_per_variant_body(spec: &PerVariantDeriveSpec) -> String {
216    let impl_open = match &spec.trait_ref {
217        None => "impl #self_name".to_string(),
218        Some(t) => format!("impl {t} for #self_name"),
219    };
220    let method_ident_let = match &spec.method_name_template {
221        None => String::new(),
222        Some(tpl) => format!(
223            "            let method_ident = quote::format_ident!(\"{tpl}\", variant_name.to_string());\n"
224        ),
225    };
226    let prelude = spec.impl_prelude.as_deref().unwrap_or("");
227    let tpl = &spec.per_variant_template;
228
229    let mut out = String::new();
230    out.push_str("        let self_name = &input.ident;\n");
231    out.push_str("        let variants = match &input.data {\n");
232    out.push_str("            syn::Data::Enum(syn::DataEnum { variants, .. }) => variants,\n");
233    out.push_str("            _ => return syn::Error::new_spanned(self_name, \"composite per-variant member needs an enum\").to_compile_error(),\n");
234    out.push_str("        };\n");
235    out.push_str("        let per_variant = variants.iter().map(|v| {\n");
236    out.push_str("            let variant_name = &v.ident;\n");
237    out.push_str("            let variant_shape_arm = match &v.fields {\n");
238    out.push_str("                syn::Fields::Named(_)   => quote! { Self::#variant_name { .. } },\n");
239    out.push_str("                syn::Fields::Unnamed(_) => quote! { Self::#variant_name(..) },\n");
240    out.push_str("                syn::Fields::Unit       => quote! { Self::#variant_name },\n");
241    out.push_str("            };\n");
242    out.push_str(&method_ident_let);
243    out.push_str("            quote! { ");
244    out.push_str(tpl);
245    out.push_str(" }\n");
246    out.push_str("        });\n");
247    out.push_str(&format!(
248        "        quote! {{ {impl_open} {{ {prelude} #(#per_variant)* }} }}\n"
249    ));
250    out
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use tatara_rust_derive::{PerFieldDeriveSpec, PerFieldTarget};

    /// Getter + setter bundle used by most tests below.
    fn accessor_bundle() -> CompositeDeriveSpec {
        let getter = PerFieldDeriveSpec {
            trait_name: Ident::new("AccessorGetter"),
            target: PerFieldTarget::NamedStruct,
            trait_ref: None,
            per_field_template:
                "pub fn #field_name(&self) -> &#field_ty { &self.#field_name }".into(),
            method_name_template: None,
            impl_prelude: None,
            skip_fields: vec![],
            field_attribute: None,
        };
        let setter = PerFieldDeriveSpec {
            trait_name: Ident::new("AccessorSetter"),
            target: PerFieldTarget::NamedStruct,
            trait_ref: None,
            per_field_template:
                "pub fn #method_ident(&mut self, v: #field_ty) { self.#field_name = v; }".into(),
            method_name_template: Some("set_{}".into()),
            impl_prelude: None,
            skip_fields: vec![],
            field_attribute: None,
        };
        CompositeDeriveSpec {
            bundle_name: Ident::new("Accessor"),
            members: vec![
                CompositeMember::PerField(getter),
                CompositeMember::PerField(setter),
            ],
        }
    }

    #[test]
    fn compiles_to_lib_and_cargo() {
        let s = accessor_bundle().compile_to_crate("accessor-derive").unwrap();
        let files = s.to_files();
        assert!(files.contains_key("Cargo.toml"));
        assert!(files.contains_key("src/lib.rs"));
    }

    #[test]
    fn lib_rs_emits_one_proc_macro_for_bundle() {
        let s = accessor_bundle().compile_to_crate("a").unwrap();
        let lib = s.to_files().get("src/lib.rs").unwrap().clone();
        // Exactly ONE outer derive attribute, no matter how many inner members.
        assert_eq!(
            lib.matches("#[proc_macro_derive(Accessor)]").count(),
            1,
            "expected one outer derive, got: {lib}"
        );
    }

    #[test]
    fn lib_rs_creates_one_closure_per_member() {
        let s = accessor_bundle().compile_to_crate("a").unwrap();
        let lib = s.to_files().get("src/lib.rs").unwrap().clone();
        assert!(lib.contains("__member_0"));
        assert!(lib.contains("__member_1"));
        assert!(lib.contains("__out_0"));
        assert!(lib.contains("__out_1"));
    }

    #[test]
    fn lib_rs_stitches_member_outputs() {
        let s = accessor_bundle().compile_to_crate("a").unwrap();
        let lib = s.to_files().get("src/lib.rs").unwrap().clone();
        assert!(lib.contains("quote! { #__out_0 #__out_1 }"));
    }

    #[test]
    fn fn_name_snake_cases_bundle_name() {
        let spec = CompositeDeriveSpec {
            bundle_name: Ident::new("MyAccessor"),
            members: vec![],
        };
        assert_eq!(spec.fn_name(), "derive_my_accessor");
    }

    #[test]
    fn lib_rs_names_proc_macro_fn_after_bundle() {
        let s = accessor_bundle().compile_to_crate("a").unwrap();
        let lib = s.to_files().get("src/lib.rs").unwrap().clone();
        assert!(
            lib.contains("pub fn derive_accessor(input: TokenStream) -> TokenStream {"),
            "missing derive fn, got: {lib}"
        );
    }

    #[test]
    fn empty_bundle_still_emits_single_derive() {
        let spec = CompositeDeriveSpec {
            bundle_name: Ident::new("Nothing"),
            members: vec![],
        };
        let s = spec.compile_to_crate("nothing-derive").unwrap();
        let lib = s.to_files().get("src/lib.rs").unwrap().clone();
        assert_eq!(lib.matches("#[proc_macro_derive(Nothing)]").count(), 1);
        assert!(!lib.contains("__member_"));
        assert!(!lib.contains("__out_"));
    }

    #[test]
    fn setter_member_builds_method_ident_from_template() {
        let s = accessor_bundle().compile_to_crate("a").unwrap();
        let lib = s.to_files().get("src/lib.rs").unwrap().clone();
        assert!(
            lib.contains("quote::format_ident!(\"set_{}\""),
            "missing setter ident construction, got: {lib}"
        );
    }

    #[test]
    fn serde_tags_members_in_kebab_case() {
        let j = serde_json::to_string(&accessor_bundle()).unwrap();
        assert!(j.contains(r#""kind":"per-field""#), "got: {j}");
    }

    #[test]
    fn serde_roundtrip() {
        let s = accessor_bundle();
        let j = serde_json::to_string(&s).unwrap();
        let back: CompositeDeriveSpec = serde_json::from_str(&j).unwrap();
        assert_eq!(s, back);
    }
}