Skip to main content

specta_serde/
lib.rs

1//! [Serde](https://serde.rs) support for Specta
2//!
3//! This crate parses `#[serde(...)]` attributes and applies the necessary transformations to your types.
4//! This is possible because the Specta macro crate stores discovered macro attributes in the [specta::DataType] definition of your type.
5//!
6//! For specific attributes, refer to Serde's [official documentation](https://serde.rs/attributes.html).
7//!
8//! # Usage
9//!
10//! ## Transform a TypeCollection in-place
11//!
12//! ```ignore
13//! use specta::TypeCollection;
14//! use specta_serde::{apply, SerdeMode};
15//!
16//! let mut types = TypeCollection::default();
17//! // Add your types...
18//!
19//! // For serialization only
20//! apply(&mut types, SerdeMode::Serialize)?;
21//!
22//! // For deserialization only
23//! apply(&mut types, SerdeMode::Deserialize)?;
24//!
25//! // For both (uses common attributes, skips mode-specific ones)
26//! apply(&mut types, SerdeMode::Both)?;
27//! ```
28//!
29//! ## Transform a single DataType
30//!
31//! ```ignore
32//! use specta::DataType;
33//! use specta_serde::{apply_to_dt, SerdeMode};
34//!
35//! let dt = DataType::Primitive(specta::datatype::Primitive::String);
36//! let transformed = apply_to_dt(dt, SerdeMode::Serialize)?;
37//! ```
38//!
39//! ## Understanding SerdeMode
40//!
41//! - `SerdeMode::Serialize`: Apply transformations for serialization (Rust → JSON/etc).
42//!   Respects `skip_serializing`, `rename_serialize`, etc.
43//!
44//! - `SerdeMode::Deserialize`: Apply transformations for deserialization (JSON/etc → Rust).
45//!   Respects `skip_deserializing`, `rename_deserialize`, etc.
46//!
47//! - `SerdeMode::Both`: Apply transformations that work for both directions.
48//!   - Uses common attributes like `rename`, `rename_all`, `skip`
49//!   - Only skips fields/types that are skipped in BOTH modes
50//!   - Ignores mode-specific attributes unless they match in both modes
51//!   - Useful when you want a single type definition for bidirectional APIs
52#![cfg_attr(docsrs, feature(doc_cfg))]
53#![doc(
54    html_logo_url = "https://github.com/specta-rs/specta/raw/main/.github/logo-128.png",
55    html_favicon_url = "https://github.com/specta-rs/specta/raw/main/.github/logo-128.png"
56)]
57
58mod error;
59mod inflection;
60mod repr;
61mod serde_attrs;
62
63pub use error::Error;
64pub use repr::EnumRepr;
65pub use serde_attrs::{SerdeMode, apply_serde_transformations};
66
67use specta::TypeCollection;
68use specta::datatype::{
69    DataType, Enum, Fields, Generic, NamedReference, Primitive, Reference, RuntimeAttribute,
70    RuntimeMeta, RuntimeNestedMeta, skip_fields, skip_fields_named,
71};
72use std::collections::HashSet;
73
74/// Apply Serde attributes to a [TypeCollection] in-place.
75///
76/// This function validates all types in the collection, then applies serde transformations
77/// according to the specified mode.
78///
79/// # Modes
80///
81/// - [`SerdeMode::Serialize`]: Apply transformations for serialization (Rust → JSON/etc)
82/// - [`SerdeMode::Deserialize`]: Apply transformations for deserialization (JSON/etc → Rust)
83/// - [`SerdeMode::Both`]: Apply common transformations (useful for bidirectional APIs)
84///
85/// The validation ensures:
86/// - Map keys are valid types (string/number types)
87/// - Internally tagged enums are properly structured
88/// - Skip attributes don't result in empty enums
89///
90/// # Example
91/// ```ignore
92/// use specta_serde::{apply, SerdeMode};
93///
94/// let mut types = specta::TypeCollection::default();
95/// // For serialization only
96/// apply(&mut types, SerdeMode::Serialize)?;
97///
98/// // For both serialization and deserialization
99/// apply(&mut types, SerdeMode::Both)?;
100/// ```
101pub fn apply(types: &mut TypeCollection, mode: SerdeMode) -> Result<(), Error> {
102    // First validate all types before transformation
103    for ndt in types.into_unsorted_iter() {
104        validate_type(ndt.ty(), types, &[], &mut Default::default())?;
105    }
106
107    // Apply transformations to each type in the collection
108    let mut transform_error = None;
109    let transformed = types.clone().map(|mut ndt| {
110        // Apply serde transformations - we validated above so this should succeed
111        // Pass the type name for struct tagging
112        match serde_attrs::apply_serde_transformations_with_name(ndt.ty(), ndt.name(), mode) {
113            Ok(transformed_dt) => {
114                ndt.set_ty(transformed_dt);
115                ndt
116            }
117            Err(err) => {
118                if transform_error.is_none() {
119                    transform_error = Some(err);
120                }
121                ndt
122            }
123        }
124    });
125
126    if let Some(err) = transform_error {
127        return Err(err);
128    }
129
130    // Validate transformed types
131    for ndt in transformed.into_unsorted_iter() {
132        validate_type(ndt.ty(), &transformed, &[], &mut Default::default())?;
133    }
134
135    // Replace the original collection with the transformed one
136    *types = transformed;
137
138    Ok(())
139}
140
141/// Apply Serde attributes to a single [DataType].
142///
143/// This function takes a DataType, applies serde transformations according to the
144/// specified mode, and returns the transformed DataType.
145///
146/// # Example
147/// ```ignore
148/// let dt = DataType::Primitive(Primitive::String);
149/// let transformed = specta_serde::apply_to_dt(dt, SerdeMode::Serialize)?;
150/// ```
151pub fn apply_to_dt(dt: DataType, mode: SerdeMode) -> Result<DataType, Error> {
152    serde_attrs::apply_serde_transformations(&dt, mode)
153}
154
155/// Process a TypeCollection and return transformed types for serialization
156///
157/// This is a convenience function that creates a new TypeCollection with serde transformations
158/// applied for serialization. For in-place transformation, use [`apply`] instead.
159///
160/// # Example
161/// ```ignore
162/// let types = specta::TypeCollection::default();
163/// let ser_types = specta_serde::process_for_serialization(&types)?;
164/// ```
165#[doc(hidden)]
166pub fn process_for_serialization(types: &TypeCollection) -> Result<TypeCollection, Error> {
167    let mut cloned = types.clone();
168    apply(&mut cloned, SerdeMode::Serialize)?;
169    Ok(cloned)
170}
171
172/// Process a TypeCollection and return transformed types for deserialization
173///
174/// This is a convenience function that creates a new TypeCollection with serde transformations
175/// applied for deserialization. For in-place transformation, use [`apply`] instead.
176///
177/// # Example
178/// ```ignore
179/// let types = specta::TypeCollection::default();
180/// let de_types = specta_serde::process_for_deserialization(&types)?;
181/// ```
182#[doc(hidden)]
183pub fn process_for_deserialization(types: &TypeCollection) -> Result<TypeCollection, Error> {
184    let mut cloned = types.clone();
185    apply(&mut cloned, SerdeMode::Deserialize)?;
186    Ok(cloned)
187}
188
189/// Process types for both serialization and deserialization
190///
191/// This is a convenience function that returns separate TypeCollections for serialization
192/// and deserialization. For in-place transformation, use [`apply`] instead.
193///
194/// Returns a tuple of (serialization_types, deserialization_types)
195///
196/// # Example
197/// ```ignore
198/// let types = specta::TypeCollection::default();
199/// let (ser_types, de_types) = specta_serde::process_for_both(&types)?;
200/// ```
201#[doc(hidden)]
202pub fn process_for_both(types: &TypeCollection) -> Result<(TypeCollection, TypeCollection), Error> {
203    let ser_types = process_for_serialization(types)?;
204    let de_types = process_for_deserialization(types)?;
205    Ok((ser_types, de_types))
206}
207
208/// Internal validation function that recursively validates types
209fn validate_type(
210    dt: &DataType,
211    types: &TypeCollection,
212    generics: &[(Generic, DataType)],
213    checked_references: &mut HashSet<NamedReference>,
214) -> Result<(), Error> {
215    match dt {
216        DataType::Nullable(ty) => validate_type(ty, types, generics, checked_references)?,
217        DataType::Map(ty) => {
218            is_valid_map_key(ty.key_ty(), types, generics)?;
219            validate_type(ty.value_ty(), types, generics, checked_references)?;
220        }
221        DataType::Struct(ty) => match ty.fields() {
222            Fields::Unit => {}
223            Fields::Unnamed(ty) => {
224                for (_, ty) in skip_fields(ty.fields()) {
225                    validate_type(ty, types, generics, checked_references)?;
226                }
227            }
228            Fields::Named(ty) => {
229                for (_, (_, ty)) in skip_fields_named(ty.fields()) {
230                    validate_type(ty, types, generics, checked_references)?;
231                }
232            }
233        },
234        DataType::Enum(ty) => {
235            validate_enum(ty, types)?;
236
237            for (_variant_name, variant) in ty.variants().iter() {
238                if variant.skip() {
239                    continue;
240                }
241
242                match &variant.fields() {
243                    Fields::Unit => {}
244                    Fields::Named(variant) => {
245                        for (_, (_, ty)) in skip_fields_named(variant.fields()) {
246                            validate_type(ty, types, generics, checked_references)?;
247                        }
248                    }
249                    Fields::Unnamed(variant) => {
250                        for (_, ty) in skip_fields(variant.fields()) {
251                            validate_type(ty, types, generics, checked_references)?;
252                        }
253                    }
254                }
255            }
256        }
257        DataType::Tuple(ty) => {
258            for ty in ty.elements() {
259                validate_type(ty, types, generics, checked_references)?;
260            }
261        }
262        DataType::List(ty) => {
263            validate_type(ty.ty(), types, generics, checked_references)?;
264        }
265        DataType::Reference(Reference::Named(r)) => {
266            for (_, dt) in r.generics() {
267                validate_type(dt, types, &[], checked_references)?;
268            }
269
270            if !checked_references.contains(r) {
271                checked_references.insert(r.clone());
272                if let Some(ndt) = r.get(types) {
273                    validate_type(ndt.ty(), types, r.generics(), checked_references)?;
274                }
275            }
276        }
277        DataType::Reference(Reference::Opaque(_)) => {}
278        _ => {}
279    }
280
281    Ok(())
282}
283
284// Typescript: Must be assignable to `string | number | symbol` says Typescript.
285fn is_valid_map_key(
286    key_ty: &DataType,
287    types: &TypeCollection,
288    generics: &[(Generic, DataType)],
289) -> Result<(), Error> {
290    match key_ty {
291        DataType::Primitive(
292            Primitive::i8
293            | Primitive::i16
294            | Primitive::i32
295            | Primitive::i64
296            | Primitive::i128
297            | Primitive::isize
298            | Primitive::u8
299            | Primitive::u16
300            | Primitive::u32
301            | Primitive::u64
302            | Primitive::u128
303            | Primitive::usize
304            | Primitive::f32
305            | Primitive::f64
306            | Primitive::String
307            | Primitive::char,
308        ) => Ok(()),
309        DataType::Primitive(_) => Err(Error::InvalidMapKey),
310        // Enum of other valid types are also valid Eg. `"A" | "B"` or `"A" | 5` are valid
311        DataType::Enum(ty) => {
312            for (_variant_name, variant) in ty.variants() {
313                match &variant.fields() {
314                    Fields::Unit => {}
315                    Fields::Unnamed(item) => {
316                        if item.fields().len() > 1 {
317                            return Err(Error::InvalidMapKey);
318                        }
319
320                        // TODO: Check enum representation for untagged requirement
321                        // if *ty.repr().unwrap_or(&EnumRepr::External) != EnumRepr::Untagged {
322                        //     return Err(Error::InvalidMapKey);
323                        // }
324                    }
325                    _ => return Err(Error::InvalidMapKey),
326                }
327            }
328
329            Ok(())
330        }
331        DataType::Tuple(t) => {
332            if t.elements().is_empty() {
333                return Err(Error::InvalidMapKey);
334            }
335
336            Ok(())
337        }
338        DataType::Reference(Reference::Named(r)) => {
339            if let Some(ndt) = r.get(types) {
340                is_valid_map_key(ndt.ty(), types, r.generics())?;
341            }
342            Ok(())
343        }
344        DataType::Reference(Reference::Opaque(_)) => Ok(()),
345        DataType::Generic(g) => {
346            let ty = generics
347                .iter()
348                .find(|(ge, _)| ge == g)
349                .map(|(_, dt)| dt)
350                .expect("unable to find expected generic type"); // TODO: Proper error instead of panicking
351
352            is_valid_map_key(ty, types, &[])
353        }
354        _ => Err(Error::InvalidMapKey),
355    }
356}
357
358// Serde does not allow serializing a variant of certain enum shapes.
359fn validate_enum(e: &Enum, types: &TypeCollection) -> Result<(), Error> {
360    if matches!(get_enum_repr(e.attributes()), EnumRepr::Internal { .. }) {
361        validate_internally_tag_enum(e, types, &mut Default::default())?;
362    }
363
364    Ok(())
365}
366
367fn validate_internally_tag_enum(
368    e: &Enum,
369    types: &TypeCollection,
370    checked_references: &mut HashSet<NamedReference>,
371) -> Result<(), Error> {
372    for (_variant_name, variant) in e.variants() {
373        if variant.skip() {
374            continue;
375        }
376
377        match &variant.fields() {
378            Fields::Unit | Fields::Named(_) => {}
379            Fields::Unnamed(item) => {
380                let mut fields = skip_fields(item.fields());
381
382                let Some((_, first_field)) = fields.next() else {
383                    continue;
384                };
385
386                if fields.next().is_some() {
387                    return Err(Error::InvalidInternallyTaggedEnum);
388                }
389
390                validate_internally_tag_enum_datatype(first_field, types, checked_references)?;
391            }
392        }
393    }
394
395    Ok(())
396}
397
398fn validate_internally_tag_enum_datatype(
399    ty: &DataType,
400    types: &TypeCollection,
401    checked_references: &mut HashSet<NamedReference>,
402) -> Result<(), Error> {
403    match ty {
404        DataType::Map(_) => Ok(()),
405        DataType::Struct(ty) => match ty.fields() {
406            Fields::Unit | Fields::Named(_) => Ok(()),
407            Fields::Unnamed(unnamed) => {
408                if !is_transparent_struct(ty.attributes()) {
409                    return Err(Error::InvalidInternallyTaggedEnum);
410                }
411
412                let mut fields = skip_fields(unnamed.fields());
413
414                let Some((_, inner_field)) = fields.next() else {
415                    return Ok(());
416                };
417
418                if fields.next().is_some() {
419                    return Err(Error::InvalidInternallyTaggedEnum);
420                }
421
422                validate_internally_tag_enum_datatype(inner_field, types, checked_references)
423            }
424        },
425        DataType::Enum(ty) => match get_enum_repr(ty.attributes()) {
426            EnumRepr::Internal { .. } | EnumRepr::Adjacent { .. } => Ok(()),
427            EnumRepr::Untagged => {
428                for (_variant_name, variant) in ty.variants() {
429                    match variant.fields() {
430                        Fields::Unit | Fields::Named(_) => {}
431                        Fields::Unnamed(unnamed) => {
432                            let mut fields = skip_fields(unnamed.fields());
433
434                            let Some((_, inner_field)) = fields.next() else {
435                                continue;
436                            };
437
438                            if fields.next().is_some() {
439                                return Err(Error::InvalidInternallyTaggedEnum);
440                            }
441
442                            validate_internally_tag_enum_datatype(
443                                inner_field,
444                                types,
445                                checked_references,
446                            )?;
447                        }
448                    }
449                }
450
451                Ok(())
452            }
453            EnumRepr::External | EnumRepr::String { .. } => Err(Error::InvalidInternallyTaggedEnum),
454        },
455        DataType::Tuple(ty) if ty.elements().is_empty() => Ok(()),
456        DataType::Reference(Reference::Named(r)) => {
457            if !checked_references.contains(r) {
458                checked_references.insert(r.clone());
459                if let Some(ndt) = r.get(types) {
460                    validate_internally_tag_enum_datatype(ndt.ty(), types, checked_references)?;
461                }
462            }
463
464            Ok(())
465        }
466        DataType::Nullable(ty) => {
467            validate_internally_tag_enum_datatype(ty, types, checked_references)
468        }
469        DataType::Reference(Reference::Opaque(_)) | DataType::Generic(_) => Ok(()),
470        _ => Err(Error::InvalidInternallyTaggedEnum),
471    }
472}
473
474fn is_transparent_struct(attributes: &[RuntimeAttribute]) -> bool {
475    attributes.iter().any(|attr| {
476        if attr.path != "serde" && attr.path != "specta" {
477            return false;
478        }
479
480        match &attr.kind {
481            RuntimeMeta::Path(path) => path == "transparent",
482            RuntimeMeta::List(items) => items.iter().any(|item| {
483                matches!(item, RuntimeNestedMeta::Meta(RuntimeMeta::Path(path)) if path == "transparent")
484            }),
485            RuntimeMeta::NameValue { .. } => false,
486        }
487    })
488}
489
490/// Check if a field has the `#[serde(flatten)]` attribute
491///
492/// This is a utility function for exporters that need to handle flattened fields.
493/// It checks both `#[serde(flatten)]` and `#[specta(flatten)]` attributes.
494///
495/// # Example
496/// ```ignore
497/// use specta::datatype::Field;
498/// use specta_serde::is_field_flattened;
499///
500/// fn process_field(field: &Field) {
501///     if is_field_flattened(field) {
502///         // Handle flattened field
503///     } else {
504///         // Handle regular field
505///     }
506/// }
507/// ```
508pub fn is_field_flattened(field: &specta::datatype::Field) -> bool {
509    use specta::datatype::{RuntimeMeta, RuntimeNestedMeta};
510
511    field.attributes().iter().any(|attr| {
512        if attr.path == "serde" || attr.path == "specta" {
513            match &attr.kind {
514                RuntimeMeta::Path(path) => path == "flatten",
515                RuntimeMeta::List(items) => items.iter().any(|item| {
516                    matches!(item, RuntimeNestedMeta::Meta(RuntimeMeta::Path(path)) if path == "flatten")
517                }),
518                _ => false,
519            }
520        } else {
521            false
522        }
523    })
524}
525
526/// Get the enum representation from serde attributes
527///
528/// This function parses `#[serde(tag = "...")]`, `#[serde(content = "...")]`,
529/// and `#[serde(untagged)]` attributes to determine the enum representation.
530///
531/// Returns `EnumRepr::External` by default if no representation attributes are found.
532///
533/// # Example
534/// ```ignore
535/// use specta::datatype::Enum;
536/// use specta_serde::{get_enum_repr, EnumRepr};
537///
538/// fn process_enum(e: &Enum) {
539///     let repr = get_enum_repr(e.attributes());
540///     match repr {
541///         EnumRepr::External => { /* handle external */ },
542///         EnumRepr::Internal { tag } => { /* handle internal */ },
543///         EnumRepr::Adjacent { tag, content } => { /* handle adjacent */ },
544///         EnumRepr::Untagged => { /* handle untagged */ },
545///         _ => {}
546///     }
547/// }
548/// ```
549pub fn get_enum_repr(attributes: &[specta::datatype::RuntimeAttribute]) -> EnumRepr {
550    use specta::datatype::{RuntimeLiteral, RuntimeMeta, RuntimeNestedMeta, RuntimeValue};
551    use std::borrow::Cow;
552
553    let mut tag = None;
554    let mut content = None;
555    let mut untagged = false;
556
557    fn parse_repr_from_meta(
558        meta: &RuntimeMeta,
559        tag: &mut Option<String>,
560        content: &mut Option<String>,
561        untagged: &mut bool,
562    ) {
563        match meta {
564            RuntimeMeta::Path(path) => {
565                if path == "untagged" {
566                    *untagged = true;
567                }
568            }
569            RuntimeMeta::NameValue { key, value } => {
570                if key == "tag" {
571                    if let RuntimeValue::Literal(RuntimeLiteral::Str(t)) = value {
572                        *tag = Some(t.clone());
573                    }
574                } else if key == "content"
575                    && let RuntimeValue::Literal(RuntimeLiteral::Str(c)) = value
576                {
577                    *content = Some(c.clone());
578                }
579            }
580            RuntimeMeta::List(list) => {
581                for nested in list {
582                    if let RuntimeNestedMeta::Meta(nested_meta) = nested {
583                        parse_repr_from_meta(nested_meta, tag, content, untagged);
584                    }
585                }
586            }
587        }
588    }
589
590    for attr in attributes {
591        if attr.path == "serde" {
592            parse_repr_from_meta(&attr.kind, &mut tag, &mut content, &mut untagged);
593        }
594    }
595
596    if let (Some(tag_name), Some(content_name)) = (tag.clone(), content.clone()) {
597        EnumRepr::Adjacent {
598            tag: Cow::Owned(tag_name),
599            content: Cow::Owned(content_name),
600        }
601    } else if let Some(tag_name) = tag {
602        EnumRepr::Internal {
603            tag: Cow::Owned(tag_name),
604        }
605    } else if untagged {
606        EnumRepr::Untagged
607    } else {
608        EnumRepr::External
609    }
610}