tatara_rust_composite/
lib.rs1use serde::{Deserialize, Serialize};
49use tatara_rust_ast::{AstError, CompileToCrate, CrateScaffold, Ident, ToRustTokens};
50use tatara_rust_derive::{PerFieldDeriveSpec, PerVariantDeriveSpec, ProcDeriveSpec};
51
52#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
53pub struct CompositeDeriveSpec {
54 pub bundle_name: Ident,
56 pub members: Vec<CompositeMember>,
58}
59
60#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
64#[serde(tag = "kind", rename_all = "kebab-case")]
65pub enum CompositeMember {
66 Simple(ProcDeriveSpec),
67 PerField(PerFieldDeriveSpec),
68 PerVariant(PerVariantDeriveSpec),
69}
70
71impl CompositeDeriveSpec {
72 fn fn_name(&self) -> String {
73 let s = &self.bundle_name.0;
74 let mut out = String::from("derive_");
75 for (i, c) in s.chars().enumerate() {
76 if c.is_uppercase() {
77 if i > 0 {
78 out.push('_');
79 }
80 out.extend(c.to_lowercase());
81 } else {
82 out.push(c);
83 }
84 }
85 out
86 }
87}
88
89impl CompileToCrate for CompositeDeriveSpec {
90 fn compile_to_crate(&self, crate_name: &str) -> Result<CrateScaffold, AstError> {
91 let mut s = CrateScaffold::new(crate_name, "0.1.0");
92 s.add_file("Cargo.toml", render_cargo_toml(crate_name));
93 s.add_file("src/lib.rs", render_lib_rs(self)?);
94 Ok(s)
95 }
96}
97
98fn render_cargo_toml(crate_name: &str) -> String {
99 tatara_rust_ast::render_proc_macro_cargo_toml(
100 crate_name,
101 "Composite derive proc-macro — fans one #[derive(...)] out to N inner Specs.",
102 )
103}
104
105fn render_lib_rs(spec: &CompositeDeriveSpec) -> Result<String, AstError> {
112 let bundle = &spec.bundle_name.0;
113 let fn_name = spec.fn_name();
114
115 let mut closures = String::new();
118 let mut calls = String::new();
119 for (i, m) in spec.members.iter().enumerate() {
120 let cname = format!("__member_{i}");
121 closures.push_str(&format!(
122 " let {cname} = |input: &syn::DeriveInput| -> proc_macro2::TokenStream {{\n"
123 ));
124 closures.push_str(&render_member_body(m));
125 closures.push_str(" };\n");
126 calls.push_str(&format!(" let __out_{i} = {cname}(&input);\n"));
127 }
128 let stitched = (0..spec.members.len())
129 .map(|i| format!("#__out_{i}"))
130 .collect::<Vec<_>>()
131 .join(" ");
132
133 let mut out = String::new();
134 out.push_str("// GENERATED by tatara-rust-composite::CompositeDeriveSpec.\n");
135 out.push_str("use proc_macro::TokenStream;\n");
136 out.push_str("use quote::quote;\n");
137 out.push_str("use syn::parse_macro_input;\n\n");
138 out.push_str(&format!("#[proc_macro_derive({bundle})]\n"));
139 out.push_str(&format!(
140 "pub fn {fn_name}(input: TokenStream) -> TokenStream {{\n"
141 ));
142 out.push_str(" let input = parse_macro_input!(input as syn::DeriveInput);\n");
143 out.push_str(&closures);
144 out.push_str(&calls);
145 for i in 0..spec.members.len() {
146 let _ = i;
147 }
148 out.push_str(&format!(
149 " let expanded = quote! {{ {stitched} }};\n"
150 ));
151 out.push_str(" TokenStream::from(expanded)\n");
152 out.push_str("}\n");
153
154 Ok(out)
155}
156
157fn render_member_body(m: &CompositeMember) -> String {
161 match m {
162 CompositeMember::Simple(spec) => {
163 let body = spec.impl_template.to_rust_tokens().to_string();
166 let body = body
167 .replace(tatara_rust_derive::SENTINEL_SELF_TYPE, "#name")
168 .replace("# __SELF_NAME__", "#name")
169 .replace(tatara_rust_derive::SENTINEL_SELF_NAME, "#name");
170 format!(
171 r#" let name = &input.ident;
172 quote! {{ {body} }}
173"#
174 )
175 }
176 CompositeMember::PerField(spec) => render_per_field_body(spec),
177 CompositeMember::PerVariant(spec) => render_per_variant_body(spec),
178 }
179}
180
181fn render_per_field_body(spec: &PerFieldDeriveSpec) -> String {
182 let impl_open = match &spec.trait_ref {
183 None => "impl #self_name".to_string(),
184 Some(t) => format!("impl {t} for #self_name"),
185 };
186 let method_ident_let = match &spec.method_name_template {
187 None => String::new(),
188 Some(tpl) => format!(
189 " let method_ident = quote::format_ident!(\"{tpl}\", field_name.to_string());\n"
190 ),
191 };
192 let prelude = spec.impl_prelude.as_deref().unwrap_or("");
193 let tpl = &spec.per_field_template;
194
195 let mut out = String::new();
196 out.push_str(" let self_name = &input.ident;\n");
197 out.push_str(" let fields = match &input.data {\n");
198 out.push_str(" syn::Data::Struct(syn::DataStruct { fields: syn::Fields::Named(named), .. }) => &named.named,\n");
199 out.push_str(" _ => return syn::Error::new_spanned(self_name, \"composite per-field member needs named-fields struct\").to_compile_error(),\n");
200 out.push_str(" };\n");
201 out.push_str(" let per_field = fields.iter().map(|f| {\n");
202 out.push_str(" let field_name = f.ident.as_ref().expect(\"named\");\n");
203 out.push_str(" let field_ty = &f.ty;\n");
204 out.push_str(&method_ident_let);
205 out.push_str(" quote! { ");
206 out.push_str(tpl);
207 out.push_str(" }\n");
208 out.push_str(" });\n");
209 out.push_str(&format!(
210 " quote! {{ {impl_open} {{ {prelude} #(#per_field)* }} }}\n"
211 ));
212 out
213}
214
215fn render_per_variant_body(spec: &PerVariantDeriveSpec) -> String {
216 let impl_open = match &spec.trait_ref {
217 None => "impl #self_name".to_string(),
218 Some(t) => format!("impl {t} for #self_name"),
219 };
220 let method_ident_let = match &spec.method_name_template {
221 None => String::new(),
222 Some(tpl) => format!(
223 " let method_ident = quote::format_ident!(\"{tpl}\", variant_name.to_string());\n"
224 ),
225 };
226 let prelude = spec.impl_prelude.as_deref().unwrap_or("");
227 let tpl = &spec.per_variant_template;
228
229 let mut out = String::new();
230 out.push_str(" let self_name = &input.ident;\n");
231 out.push_str(" let variants = match &input.data {\n");
232 out.push_str(" syn::Data::Enum(syn::DataEnum { variants, .. }) => variants,\n");
233 out.push_str(" _ => return syn::Error::new_spanned(self_name, \"composite per-variant member needs an enum\").to_compile_error(),\n");
234 out.push_str(" };\n");
235 out.push_str(" let per_variant = variants.iter().map(|v| {\n");
236 out.push_str(" let variant_name = &v.ident;\n");
237 out.push_str(" let variant_shape_arm = match &v.fields {\n");
238 out.push_str(" syn::Fields::Named(_) => quote! { Self::#variant_name { .. } },\n");
239 out.push_str(" syn::Fields::Unnamed(_) => quote! { Self::#variant_name(..) },\n");
240 out.push_str(" syn::Fields::Unit => quote! { Self::#variant_name },\n");
241 out.push_str(" };\n");
242 out.push_str(&method_ident_let);
243 out.push_str(" quote! { ");
244 out.push_str(tpl);
245 out.push_str(" }\n");
246 out.push_str(" });\n");
247 out.push_str(&format!(
248 " quote! {{ {impl_open} {{ {prelude} #(#per_variant)* }} }}\n"
249 ));
250 out
251}
252
#[cfg(test)]
mod tests {
    use super::*;
    use tatara_rust_derive::{PerFieldDeriveSpec, PerFieldTarget};

    /// Two-member bundle: a getter per field plus a `set_<field>` setter.
    fn accessor_bundle() -> CompositeDeriveSpec {
        let getter = PerFieldDeriveSpec {
            trait_name: Ident::new("AccessorGetter"),
            target: PerFieldTarget::NamedStruct,
            trait_ref: None,
            per_field_template:
                "pub fn #field_name(&self) -> &#field_ty { &self.#field_name }".into(),
            method_name_template: None,
            impl_prelude: None,
            skip_fields: vec![],
            field_attribute: None,
        };
        let setter = PerFieldDeriveSpec {
            trait_name: Ident::new("AccessorSetter"),
            target: PerFieldTarget::NamedStruct,
            trait_ref: None,
            per_field_template:
                "pub fn #method_ident(&mut self, v: #field_ty) { self.#field_name = v; }".into(),
            method_name_template: Some("set_{}".into()),
            impl_prelude: None,
            skip_fields: vec![],
            field_attribute: None,
        };
        CompositeDeriveSpec {
            bundle_name: Ident::new("Accessor"),
            members: vec![
                CompositeMember::PerField(getter),
                CompositeMember::PerField(setter),
            ],
        }
    }

    /// Compile `spec` and return the generated `src/lib.rs` text.
    fn generated_lib(spec: &CompositeDeriveSpec) -> String {
        let scaffold = spec.compile_to_crate("a").expect("composite spec should compile");
        let files = scaffold.to_files();
        let lib = files
            .get("src/lib.rs")
            .expect("src/lib.rs should be emitted")
            .to_string();
        lib
    }

    #[test]
    fn compiles_to_lib_and_cargo() {
        let scaffold = accessor_bundle().compile_to_crate("accessor-derive").unwrap();
        let files = scaffold.to_files();
        assert!(files.contains_key("Cargo.toml"));
        assert!(files.contains_key("src/lib.rs"));
    }

    #[test]
    fn lib_rs_emits_one_proc_macro_for_bundle() {
        let lib = generated_lib(&accessor_bundle());
        assert_eq!(
            lib.matches("#[proc_macro_derive(Accessor)]").count(),
            1,
            "expected one outer derive, got: {lib}"
        );
    }

    #[test]
    fn lib_rs_creates_one_closure_per_member() {
        let lib = generated_lib(&accessor_bundle());
        for needle in ["__member_0", "__member_1", "__out_0", "__out_1"] {
            assert!(lib.contains(needle), "missing {needle} in: {lib}");
        }
    }

    #[test]
    fn lib_rs_stitches_member_outputs() {
        let lib = generated_lib(&accessor_bundle());
        assert!(lib.contains("quote! { #__out_0 #__out_1 }"));
    }

    #[test]
    fn lib_rs_entry_point_uses_snake_case_fn_name() {
        let lib = generated_lib(&accessor_bundle());
        assert!(
            lib.contains("pub fn derive_accessor(input: TokenStream) -> TokenStream {"),
            "entry point not found in: {lib}"
        );
    }

    #[test]
    fn lib_rs_renders_setter_method_ident() {
        let lib = generated_lib(&accessor_bundle());
        assert!(
            lib.contains("quote::format_ident!(\"set_{}\", field_name.to_string())"),
            "setter ident not found in: {lib}"
        );
    }

    #[test]
    fn fn_name_snake_cases_bundle_name() {
        assert_eq!(accessor_bundle().fn_name(), "derive_accessor");
        let multi_word = CompositeDeriveSpec {
            bundle_name: Ident::new("MyAccessorBundle"),
            members: vec![],
        };
        assert_eq!(multi_word.fn_name(), "derive_my_accessor_bundle");
    }

    #[test]
    fn serde_roundtrip() {
        let s = accessor_bundle();
        let j = serde_json::to_string(&s).unwrap();
        let back: CompositeDeriveSpec = serde_json::from_str(&j).unwrap();
        assert_eq!(s, back);
    }

    #[test]
    fn serde_tags_members_by_kebab_case_kind() {
        let json = serde_json::to_value(accessor_bundle()).unwrap();
        assert_eq!(json["members"][0]["kind"], "per-field");
        assert_eq!(json["members"][1]["kind"], "per-field");
    }
}