1#![forbid(unsafe_code)]
2pub(crate) mod error;
30pub(crate) mod fields;
31mod inv;
32mod parser;
33mod render;
34mod utils;
35
36use crate::error::generate_unsupported_compile_error;
37use crate::fields::{FieldKind, Fields};
38use crate::parser::{TemplateSegments, parse_template};
39use crate::render::generate_format_string_args;
40use darling::FromDeriveInput;
41use darling::util::{Flag, Override};
42use inv::generator::generate_str_parser;
43use proc_macro::TokenStream;
44use quote::quote;
45use std::collections::HashSet;
46use syn::{DeriveInput, parse_macro_input};
47
/// Container options for `#[derive(Template)]`, parsed by darling from the
/// deriving struct and its `#[templatia(...)]` attributes.
///
/// `supports(struct_named)` makes `TemplateOpts::from_derive_input` reject enums,
/// tuple structs and unit structs, so `data` is always the struct variant by the
/// time `template_derive` inspects it.
#[derive(Debug, FromDeriveInput)]
#[darling(attributes(templatia), supports(struct_named))]
struct TemplateOpts {
    /// Name of the deriving struct, forwarded by darling. Used as the impl target
    /// and as the span for template parse errors.
    ident: syn::Ident,
    /// Body of the deriving struct, forwarded by darling. The `()` variant type is
    /// only a placeholder: enum inputs are rejected by `supports(struct_named)`.
    data: darling::ast::Data<(), syn::Field>,
    /// `#[templatia(template = "...")]`: the template text. When the key is absent
    /// (via `#[darling(default)]`) or given as a bare word, this is
    /// `Override::Inherit`, and `template_derive` synthesizes one
    /// `field = {field}` line per named field.
    #[darling(default)]
    template: Override<String>,
    /// `#[templatia(allow_missing_placeholders)]`: forwarded to the parser
    /// generator. When present, primitive fields referenced by the template must
    /// also implement `Default`.
    #[darling(default)]
    allow_missing_placeholders: Flag,
    /// `#[templatia(empty_str_option_not_none)]`: its negation is what the parser
    /// generator receives. Presumably it makes an empty match for an `Option`
    /// field parse as `Some(..)` rather than `None`; TODO confirm in
    /// `generate_str_parser`.
    #[darling(default)]
    empty_str_option_not_none: Flag,
}
63
64#[proc_macro_derive(Template, attributes(templatia))]
84pub fn template_derive(input: TokenStream) -> TokenStream {
85 let ast = parse_macro_input!(input as DeriveInput);
86
87 let opts = match TemplateOpts::from_derive_input(&ast) {
88 Ok(opts) => opts,
89 Err(e) => return e.write_errors().into(),
90 };
91
92 let name = &opts.ident;
93
94 let template = match &opts.template {
95 Override::Explicit(template) => template.to_string(),
96 Override::Inherit => {
97 if let syn::Data::Struct(data_struct) = &ast.data {
98 if let syn::Fields::Named(fields_named) = &data_struct.fields {
99 fields_named
100 .named
101 .iter()
102 .filter_map(|field| field.ident.as_ref())
103 .map(|ident| format!("{0} = {{{0}}}", ident.to_string()))
104 .collect::<Vec<_>>()
105 .join("\n")
106 } else {
107 String::new()
108 }
109 } else {
110 String::new()
111 }
112 }
113 };
114
115 let marker_input = format!("{}::{}", name, template);
116 let hash = {
117 use std::hash::{DefaultHasher, Hash, Hasher};
118
119 let mut hasher = DefaultHasher::new();
120 marker_input.hash(&mut hasher);
121
122 hasher.finish()
123 };
124 let escaped_colon_marker = format!("<escaped_colon_templatia_{:x}>", hash);
125
126 let allow_missing_placeholders = opts.allow_missing_placeholders.is_present();
127 let empty_str_as_none = opts.empty_str_option_not_none.is_present();
128
129 let all_fields = if let darling::ast::Data::Struct(data_struct) = &opts.data {
130 &data_struct.fields
131 } else {
132 unreachable!()
134 };
135
136 let fields = Fields::new(all_fields);
137
138 let option_fields = fields
139 .option_fields()
140 .keys()
141 .copied()
142 .collect::<HashSet<_>>();
143
144 let segments = match parse_template(&template) {
145 Ok(segments) => segments,
146 Err(e) => {
147 let error =
148 syn::Error::new_spanned(&opts.ident, format!("Failed to parse template: {}", e));
149 return error.to_compile_error().into();
151 }
152 };
153
154 let (format_string, format_args) = generate_format_string_args(&segments, &option_fields);
155
156 let placeholder_names = segments
158 .iter()
159 .filter_map(|segment| {
160 if let TemplateSegments::Placeholder(name) = segment {
161 Some(name.trim().to_string())
162 } else {
163 None
164 }
165 })
166 .collect::<HashSet<_>>();
167
168 let str_from_parser = generate_str_parser(
169 name,
170 &fields,
171 &placeholder_names,
172 &segments,
173 allow_missing_placeholders,
174 !empty_str_as_none,
175 &escaped_colon_marker,
176 );
177
178 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
180
181 let mut new_where_clause = where_clause
182 .cloned()
183 .unwrap_or_else(|| syn::parse_quote! { where });
184
185 for field in fields.used_fields_in_template(&placeholder_names) {
186 if let Some(ident) = field.ident.as_ref() {
187 match fields.get_field_kind(ident) {
188 Some(FieldKind::Option(ty)) => {
189 new_where_clause.predicates.push(syn::parse_quote! {
190 #ty: ::std::fmt::Display + ::std::str::FromStr + ::std::cmp::PartialEq
191 });
192 new_where_clause.predicates.push(syn::parse_quote! {
193 <#ty as ::std::str::FromStr>::Err: ::std::fmt::Display
194 });
195 }
196 Some(FieldKind::Primitive(ty)) => {
197 if !allow_missing_placeholders {
198 new_where_clause.predicates.push(syn::parse_quote! {
199 #ty: ::std::fmt::Display + ::std::str::FromStr + ::std::cmp::PartialEq
200 });
201 } else {
202 new_where_clause.predicates.push(syn::parse_quote! {
203 #ty: ::std::fmt::Display + ::std::str::FromStr + ::std::cmp::PartialEq + ::std::default::Default
204 });
205 }
206 new_where_clause.predicates.push(syn::parse_quote! {
207 <#ty as ::std::str::FromStr>::Err: ::std::fmt::Display
208 });
209 }
210 Some(kind) => return generate_unsupported_compile_error(ident, kind).into(),
211 None => {
212 return generate_unsupported_compile_error(ident, &FieldKind::Unknown).into();
213 }
214 }
215 }
216 }
217
218 let where_clause = if new_where_clause.predicates.is_empty() {
219 quote! {}
220 } else {
221 quote! { #new_where_clause }
222 };
223
224 let replace_escaped_to_colon = quote! { replace(#escaped_colon_marker, ":") };
225
226 quote! {
227 impl #impl_generics ::templatia::Template for #name #ty_generics #where_clause {
228 type Error = templatia::TemplateError;
229
230 fn render_string(&self) -> String {
231 format!(#format_string, #(#format_args),*)
232 }
233
234 fn from_str(s: &str) -> Result<Self, Self::Error> {
235 use ::templatia::__private::chumsky;
236 use ::templatia::__private::chumsky::Parser;
237 use ::templatia::__private::chumsky::prelude::*;
238
239 let parser = #str_from_parser;
240 match parser.parse(s).into_result() {
241 Ok(value) => Ok(value),
242 Err(errs) => {
243 for err in &errs {
244 if let ::templatia::__private::chumsky::error::RichReason::Custom(msg) = err.reason() {
245 let m = msg.to_string();
246 const PFX_CONFLICT: &str = "__templatia_conflict__:";
247 const PFX_PARSE: &str = "__templatia_parse_type__:";
248 const PFX_PARSE_LITERAL: &str = "__templatia_parse_literal__:";
249 if let Some(rest) = m.strip_prefix(PFX_CONFLICT) {
250 if let Some((placeholder, rest)) = rest.split_once("::") {
251 if let Some((first_value, second_value)) = rest.split_once("::") {
252 return Err(::templatia::TemplateError::InconsistentValues {
253 placeholder: placeholder.#replace_escaped_to_colon.to_string(),
254 first_value: first_value.#replace_escaped_to_colon.to_string(),
255 second_value: second_value.#replace_escaped_to_colon.to_string(),
256 });
257 }
258 }
259 } else if let Some(rest) = m.strip_prefix(PFX_PARSE) {
260 if let Some((placeholder, rest)) = rest.split_once("::") {
261 if let Some((value, ty)) = rest.split_once("::") {
262 return Err(::templatia::TemplateError::ParseToType {
263 placeholder: placeholder.#replace_escaped_to_colon.to_string(),
264 value: value.#replace_escaped_to_colon.to_string(),
265 type_name: ty.#replace_escaped_to_colon.to_string(),
266 })
267 }
268 }
269 } else if let Some(rest) = m.strip_prefix(PFX_PARSE_LITERAL) {
270 if let Some((expected, got)) = rest.split_once("::") {
271 let expected_next_literal = expected.trim_matches('"')
272 .#replace_escaped_to_colon
273 .to_string();
274 let remaining_text = got.#replace_escaped_to_colon.to_string();
275
276 return Err(::templatia::TemplateError::UnexpectedInput {
277 expected_next_literal,
278 remaining_text,
279 })
280 }
281 }
282 }
283 }
284
285 let error_message = errs.into_iter()
286 .map(|err| err.to_string())
287 .collect::<Vec<_>>()
288 .join("\n");
289
290 Err(templatia::TemplateError::Parse(error_message))
291 }
292 }
293 }
294 }
295 }.into()
296}