shapely_derive/
lib.rs

1#![warn(missing_docs)]
2#![doc = include_str!("../README.md")]
3
4use unsynn::*;
5
6keyword! {
7    KPub = "pub";
8    KStruct = "struct";
9    KEnum = "enum";
10    KDoc = "doc";
11    KRepr = "repr";
12    KCrate = "crate";
13    KConst = "const";
14    KMut = "mut";
15}
16
17operator! {
18    Eq = "=";
19    Semi = ";";
20    Apostrophe = "'";
21    DoubleSemicolon = "::";
22}
23
24unsynn! {
25    enum Vis {
26        Pub(KPub),
27        PubCrate(Cons<KPub, ParenthesisGroupContaining<KCrate>>),
28    }
29
30    struct Attribute {
31        _pound: Pound,
32        body: BracketGroupContaining<AttributeInner>,
33    }
34
35    enum AttributeInner {
36        Doc(DocInner),
37        Repr(ReprInner),
38        Any(Vec<TokenTree>)
39    }
40
41    struct DocInner {
42        _kw_doc: KDoc,
43        _eq: Eq,
44        value: LiteralString,
45    }
46
47    struct ReprInner {
48        _kw_repr: KRepr,
49        attr: ParenthesisGroupContaining<Ident>,
50    }
51
52    struct Struct {
53        // Skip any doc attributes by consuming them
54        attributes: Vec<Attribute>,
55        _vis: Option<Vis>,
56        _kw_struct: KStruct,
57        name: Ident,
58        body: BraceGroupContaining<CommaDelimitedVec<StructField>>,
59    }
60
61    struct Lifetime {
62        _apostrophe: Apostrophe,
63        name: Ident,
64    }
65
66    enum Expr {
67        Integer(LiteralInteger),
68    }
69
70    enum Type {
71        Path(PathType),
72        Tuple(ParenthesisGroupContaining<CommaDelimitedVec<Box<Type>>>),
73        Slice(BracketGroupContaining<Box<Type>>),
74        Bare(BareType),
75    }
76
77    struct PathType {
78        prefix: Ident,
79        _doublesemi: DoubleSemicolon,
80        rest: Box<Type>,
81    }
82
83    struct BareType {
84        name: Ident,
85        generic_params: Option<GenericParams>,
86    }
87
88    struct GenericParams {
89        _lt: Lt,
90        params: CommaDelimitedVec<Type>,
91        _gt: Gt,
92    }
93
94    enum ConstOrMut {
95        Const(KConst),
96        Mut(KMut),
97    }
98
99    struct StructField {
100        attributes: Vec<Attribute>,
101        _vis: Option<Vis>,
102        name: Ident,
103        _colon: Colon,
104        typ: Type,
105    }
106
107    struct TupleStruct {
108        // Skip any doc attributes by consuming them
109        attributes: Vec<Attribute>,
110        _vis: Option<Vis>,
111        _kw_struct: KStruct,
112        name: Ident,
113        body: ParenthesisGroupContaining<CommaDelimitedVec<TupleField>>,
114    }
115
116    struct TupleField {
117        attributes: Vec<Attribute>,
118        vis: Option<Vis>,
119        typ: Type,
120    }
121
122    struct Enum {
123        // Skip any doc attributes by consuming them
124        attributes: Vec<Attribute>,
125        _pub: Option<KPub>,
126        _kw_enum: KEnum,
127        name: Ident,
128        body: BraceGroupContaining<CommaDelimitedVec<EnumVariantLike>>,
129    }
130
131    enum EnumVariantLike {
132        Unit(UnitVariant),
133        Tuple(TupleVariant),
134        Struct(StructVariant),
135    }
136
137    struct UnitVariant {
138        attributes: Vec<Attribute>,
139        name: Ident,
140    }
141
142    struct TupleVariant {
143        // Skip any doc comments on variants
144        attributes: Vec<Attribute>,
145        name: Ident,
146        _paren: ParenthesisGroupContaining<CommaDelimitedVec<TupleField>>,
147    }
148
149    struct StructVariant {
150        // Skip any doc comments on variants
151        _doc_attributes: Vec<Attribute>,
152        name: Ident,
153        _brace: BraceGroupContaining<CommaDelimitedVec<StructField>>,
154    }
155}
156
157/// Derive the [`shapely_core::Shapely`] trait for structs, tuple structs, and enums.
158///
159/// This uses unsynn, so it's light, but it _will_ choke on some Rust syntax because...
160/// there's a lot of Rust syntax.
161#[proc_macro_derive(Shapely)]
162pub fn shapely_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
163    let input = TokenStream::from(input);
164    let mut i = input.to_token_iter();
165
166    // Try to parse as struct first
167    if let Ok(parsed) = i.parse::<Struct>() {
168        return process_struct(parsed);
169    }
170    let struct_tokens_left = i.count();
171
172    // Try to parse as tuple struct
173    i = input.to_token_iter(); // Reset iterator
174    if let Ok(parsed) = i.parse::<TupleStruct>() {
175        return process_tuple_struct(parsed);
176    }
177    let tuple_struct_tokens_left = i.count();
178
179    // Try to parse as enum
180    i = input.to_token_iter(); // Reset iterator
181    if let Ok(parsed) = i.parse::<Enum>() {
182        return process_enum(parsed);
183    }
184    let enum_tokens_left = i.count();
185
186    let mut msg = format!(
187        "Could not parse input as struct, tuple struct, or enum: {}",
188        input
189    );
190
191    // Find which parsing left the fewest tokens
192    let min_tokens_left = struct_tokens_left
193        .min(tuple_struct_tokens_left)
194        .min(enum_tokens_left);
195
196    // Parse again for the one with fewest tokens left and show remaining tokens
197    if min_tokens_left == struct_tokens_left {
198        i = input.to_token_iter();
199        let err = i.parse::<Struct>().err();
200        msg = format!(
201            "{}\n====> Error parsing struct: {:?}\n====> Remaining tokens after struct parsing: {}",
202            msg,
203            err,
204            i.collect::<TokenStream>()
205        );
206    } else if min_tokens_left == tuple_struct_tokens_left {
207        i = input.to_token_iter();
208        let err = i.parse::<TupleStruct>().err();
209        msg = format!(
210            "{}\n====> Error parsing tuple struct: {:?}\n====> Remaining tokens after tuple struct parsing: {}",
211            msg,
212            err,
213            i.collect::<TokenStream>()
214        );
215    } else {
216        i = input.to_token_iter();
217        let err = i.parse::<Enum>().err();
218        msg = format!(
219            "{}\n====> Error parsing enum: {:?}\n====> Remaining tokens after enum parsing: {}",
220            msg,
221            err,
222            i.collect::<TokenStream>()
223        );
224    }
225
226    // If we get here, couldn't parse as struct, tuple struct, or enum
227    panic!("{msg}");
228}
229
230/// Processes a regular struct to implement Shapely
231///
232/// Example input:
233/// ```rust
234/// struct Blah {
235///     foo: u32,
236///     bar: String,
237/// }
238/// ```
239fn process_struct(parsed: Struct) -> proc_macro::TokenStream {
240    let struct_name = parsed.name.to_string();
241    let fields = parsed
242        .body
243        .content
244        .0
245        .iter()
246        .map(|field| field.value.name.to_string())
247        .collect::<Vec<String>>()
248        .join(", ");
249
250    // Generate the impl
251    let output = format!(
252        r#"
253            #[automatically_derived]
254            impl shapely::Shapely for {struct_name} {{
255                fn shape() -> shapely::Shape {{
256                    shapely::Shape {{
257                        name: |f, _opts| std::fmt::Write::write_str(f, "{struct_name}"),
258                        typeid: shapely::mini_typeid::of::<Self>(),
259                        layout: std::alloc::Layout::new::<Self>(),
260                        innards: shapely::Innards::Struct {{
261                            fields: shapely::struct_fields!({struct_name}, ({fields})),
262                        }},
263                        set_to_default: None,
264                        drop_in_place: Some(|ptr| unsafe {{ std::ptr::drop_in_place(ptr as *mut Self) }}),
265                    }}
266                }}
267            }}
268        "#
269    );
270    output.into_token_stream().into()
271}
272
273/// Processes a tuple struct to implement Shapely
274///
275/// Example input:
276/// ```rust
277/// struct Point(f32, f32);
278/// ```
279fn process_tuple_struct(parsed: TupleStruct) -> proc_macro::TokenStream {
280    let struct_name = parsed.name.to_string();
281
282    // Generate field names for tuple elements (0, 1, 2, etc.)
283    let fields = parsed
284        .body
285        .content
286        .0
287        .iter()
288        .enumerate()
289        .map(|(idx, _)| idx.to_string())
290        .collect::<Vec<String>>();
291
292    // Create the fields string for struct_fields! macro
293    let fields_str = fields.join(", ");
294
295    // Generate the impl
296    let output = format!(
297        r#"
298            impl shapely::Shapely for {struct_name} {{
299                fn shape() -> shapely::Shape {{
300                    shapely::Shape {{
301                        name: |f, _opts| std::fmt::Write::write_str(f, "{struct_name}"),
302                        typeid: shapely::mini_typeid::of::<Self>(),
303                        layout: std::alloc::Layout::new::<Self>(),
304                        innards: shapely::Innards::TupleStruct {{
305                            fields: shapely::struct_fields!({struct_name}, ({fields_str})),
306                        }},
307                        set_to_default: None,
308                        drop_in_place: Some(|ptr| unsafe {{ std::ptr::drop_in_place(ptr as *mut Self) }}),
309                    }}
310                }}
311            }}
312        "#
313    );
314    output.into_token_stream().into()
315}
316
317/// Processes an enum to implement Shapely
318///
319/// Example input:
320/// ```rust
321/// #[repr(u8)]
322/// enum Color {
323///     Red,
324///     Green,
325///     Blue(u8, u8),
326///     Custom { r: u8, g: u8, b: u8 }
327/// }
328/// ```
329fn process_enum(parsed: Enum) -> proc_macro::TokenStream {
330    let enum_name = parsed.name.to_string();
331
332    // Check for explicit repr attribute
333    let has_repr = parsed
334        .attributes
335        .iter()
336        .any(|attr| matches!(attr.body.content, AttributeInner::Repr(_)));
337
338    if !has_repr {
339        return r#"compile_error!("Enums must have an explicit representation (e.g. #[repr(u8)]) to be used with Shapely")"#
340            .into_token_stream()
341            .into();
342    }
343
344    // Process each variant
345    let variants = parsed
346        .body
347        .content
348        .0
349        .iter()
350        .map(|var_like| match &var_like.value {
351            EnumVariantLike::Unit(unit) => {
352                let variant_name = unit.name.to_string();
353                format!("shapely::enum_unit_variant!({enum_name}, {variant_name})")
354            }
355            EnumVariantLike::Tuple(tuple) => {
356                let variant_name = tuple.name.to_string();
357                let field_types = tuple
358                    ._paren
359                    .content
360                    .0
361                    .iter()
362                    .map(|field| field.value.typ.to_string())
363                    .collect::<Vec<String>>()
364                    .join(", ");
365
366                format!(
367                    "shapely::enum_tuple_variant!({enum_name}, {variant_name}, [{field_types}])"
368                )
369            }
370            EnumVariantLike::Struct(struct_var) => {
371                let variant_name = struct_var.name.to_string();
372                let fields = struct_var
373                    ._brace
374                    .content
375                    .0
376                    .iter()
377                    .map(|field| {
378                        let name = field.value.name.to_string();
379                        let typ = field.value.typ.to_string();
380                        format!("{name}: {typ}")
381                    })
382                    .collect::<Vec<String>>()
383                    .join(", ");
384
385                format!("shapely::enum_struct_variant!({enum_name}, {variant_name}, {{{fields}}})")
386            }
387        })
388        .collect::<Vec<String>>()
389        .join(", ");
390
391    // Extract the repr type
392    let mut repr_type = "Default"; // Default fallback
393    for attr in &parsed.attributes {
394        if let AttributeInner::Repr(repr_attr) = &attr.body.content {
395            repr_type = match repr_attr.attr.content.to_string().as_str() {
396                "u8" => "U8",
397                "u16" => "U16",
398                "u32" => "U32",
399                "u64" => "U64",
400                "usize" => "USize",
401                "i8" => "I8",
402                "i16" => "I16",
403                "i32" => "I32",
404                "i64" => "I64",
405                "isize" => "ISize",
406                _ => "Default", // Unknown repr type
407            };
408            break;
409        }
410    }
411
412    // Generate the impl
413    let output = format!(
414        r#"
415            impl shapely::Shapely for {enum_name} {{
416                fn shape() -> shapely::Shape {{
417                    shapely::Shape {{
418                        name: |f, _opts| std::fmt::Write::write_str(f, "{enum_name}"),
419                        typeid: shapely::mini_typeid::of::<Self>(),
420                        layout: std::alloc::Layout::new::<Self>(),
421                        innards: shapely::Innards::Enum {{
422                            variants: shapely::enum_variants!({enum_name}, [{variants}]),
423                            repr: shapely::EnumRepr::{repr_type},
424                        }},
425                        set_to_default: None,
426                        drop_in_place: Some(|ptr| unsafe {{ std::ptr::drop_in_place(ptr as *mut Self) }}),
427                    }}
428                }}
429            }}
430        "#
431    );
432    output.into_token_stream().into()
433}
434
435impl std::fmt::Display for Type {
436    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
437        match self {
438            Type::Path(path) => {
439                write!(f, "{}::{}", path.prefix, path.rest)
440            }
441            Type::Tuple(tuple) => {
442                write!(f, "(")?;
443                for (i, typ) in tuple.content.0.iter().enumerate() {
444                    if i > 0 {
445                        write!(f, ", ")?;
446                    }
447                    write!(f, "{}", typ.value)?;
448                }
449                write!(f, ")")
450            }
451            Type::Slice(slice) => {
452                write!(f, "[{}]", slice.content)
453            }
454            Type::Bare(ident) => {
455                write!(f, "{}", ident.name)?;
456                if let Some(generic_params) = &ident.generic_params {
457                    write!(f, "<")?;
458                    for (i, param) in generic_params.params.0.iter().enumerate() {
459                        if i > 0 {
460                            write!(f, ", ")?;
461                        }
462                        write!(f, "{}", param.value)?;
463                    }
464                    write!(f, ">")?;
465                }
466                Ok(())
467            }
468        }
469    }
470}
471
472impl std::fmt::Display for ConstOrMut {
473    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
474        match self {
475            ConstOrMut::Const(_) => write!(f, "const"),
476            ConstOrMut::Mut(_) => write!(f, "mut"),
477        }
478    }
479}
480
481impl std::fmt::Display for Lifetime {
482    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
483        write!(f, "'{}", self.name)
484    }
485}
486
487impl std::fmt::Display for Expr {
488    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
489        match self {
490            Expr::Integer(int) => write!(f, "{}", int.value()),
491        }
492    }
493}