torn_api_codegen/model/
union.rs

1use heck::ToSnakeCase;
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4
5use crate::openapi::path::OpenApiResponseBody;
6
7use super::WarningReporter;
8
/// A named OpenAPI union (`anyOf`), described by the schema references it may hold.
#[derive(Clone, Debug)]
pub struct Union {
    /// Name of the Rust type generated for this union.
    pub name: String,
    /// Reference paths of the member schemas (e.g. `#/components/schemas/Foo`).
    /// When built via `Union::from_schema`, repeated paths are kept only once.
    pub members: Vec<String>,
}
14
15impl Union {
16    pub fn from_schema(
17        name: &str,
18        schema: &OpenApiResponseBody,
19        warnings: WarningReporter,
20    ) -> Option<Self> {
21        let members = match schema {
22            OpenApiResponseBody::Union { any_of } => {
23                let mut members = Vec::with_capacity(any_of.len());
24                for l in any_of {
25                    let path = l.ref_path.to_owned();
26                    if members.contains(&path) {
27                        warnings.push(format!("Duplicate member: {path}"));
28                    } else {
29                        members.push(path);
30                    }
31                }
32                members
33            }
34            _ => return None,
35        };
36        let name = name.to_owned();
37
38        Some(Self { name, members })
39    }
40
41    pub fn codegen(&self) -> Option<TokenStream> {
42        let name = format_ident!("{}", self.name);
43        let mut variants = Vec::new();
44
45        for member in &self.members {
46            let variant_name = member.strip_prefix("#/components/schemas/")?;
47            let accessor_name = format_ident!("{}", variant_name.to_snake_case());
48            let ty_name = format_ident!("{}", variant_name);
49            variants.push(quote! {
50                pub fn #accessor_name(&self) -> Result<crate::models::#ty_name, serde_json::Error> {
51                    self.deserialize()
52                }
53            });
54        }
55
56        Some(quote! {
57            #[derive(Debug, Clone, serde::Deserialize)]
58            pub struct #name(serde_json::Value);
59
60            impl #name {
61                pub fn deserialize<'de, T>(&'de self) -> Result<T, serde_json::Error>
62                where
63                    T: serde::Deserialize<'de>,
64                {
65                    T::deserialize(&self.0)
66                }
67
68                #(#variants)*
69            }
70        })
71    }
72}