torn_api_codegen/model/
enum.rs

1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use syn::Ident;
5
6use crate::openapi::{
7    parameter::OpenApiParameterSchema,
8    r#type::{OpenApiType, OpenApiVariants},
9};
10
11use super::{object::PrimitiveType, Model, ResolvedSchema};
12
/// Integer representation selected for an enum generated from numeric variants.
///
/// The chosen value decides which `#[repr(..)]` attribute is written onto the
/// generated enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnumRepr {
    /// Rendered as `#[repr(u8)]`.
    U8,
    /// Rendered as `#[repr(u32)]`.
    U32,
}
18
19#[derive(Debug, Clone, PartialEq, Eq)]
20pub enum EnumVariantTupleValue {
21    Ref { ty_name: String },
22    ArrayOfRefs { ty_name: String },
23    Primitive(PrimitiveType),
24    Enum { name: String, inner: Enum },
25}
26
27impl EnumVariantTupleValue {
28    pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
29        match schema {
30            OpenApiType {
31                ref_path: Some(path),
32                ..
33            } => Some(Self::Ref {
34                ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
35            }),
36            OpenApiType {
37                r#type: Some("array"),
38                items: Some(items),
39                ..
40            } => {
41                let OpenApiType {
42                    ref_path: Some(path),
43                    ..
44                } = items.as_ref()
45                else {
46                    return None;
47                };
48                Some(Self::ArrayOfRefs {
49                    ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
50                })
51            }
52            OpenApiType {
53                r#type: Some("string"),
54                format: None,
55                r#enum: None,
56                ..
57            } => Some(Self::Primitive(PrimitiveType::String)),
58            OpenApiType {
59                r#type: Some("string"),
60                format: None,
61                r#enum: Some(_),
62                ..
63            } => {
64                let name = format!("{name}Variant");
65                Some(Self::Enum {
66                    inner: Enum::from_schema(&name, schema)?,
67                    name,
68                })
69            }
70            OpenApiType {
71                r#type: Some("integer"),
72                format: Some("int64"),
73                ..
74            } => Some(Self::Primitive(PrimitiveType::I64)),
75            OpenApiType {
76                r#type: Some("integer"),
77                format: Some("int32"),
78                ..
79            } => Some(Self::Primitive(PrimitiveType::I32)),
80            _ => None,
81        }
82    }
83
84    pub fn type_name(&self, ns: &mut EnumNamespace) -> TokenStream {
85        match self {
86            Self::Ref { ty_name } => {
87                let ty = format_ident!("{ty_name}");
88                quote! { crate::models::#ty }
89            }
90            Self::ArrayOfRefs { ty_name } => {
91                let ty = format_ident!("{ty_name}");
92                quote! { Vec<crate::models::#ty> }
93            }
94            Self::Primitive(PrimitiveType::I64) => quote! { i64 },
95            Self::Primitive(PrimitiveType::I32) => quote! { i32 },
96            Self::Primitive(PrimitiveType::Float) => quote! { f32 },
97            Self::Primitive(PrimitiveType::String) => quote! { String },
98            Self::Primitive(PrimitiveType::DateTime) => quote! { chrono::DateTime<chrono::Utc> },
99            Self::Primitive(PrimitiveType::Bool) => quote! { bool },
100            Self::Enum { name, .. } => {
101                let path = ns.get_ident();
102                let ty_name = format_ident!("{name}");
103                quote! {
104                    #path::#ty_name,
105                }
106            }
107        }
108    }
109
110    pub fn name(&self) -> String {
111        match self {
112            Self::Ref { ty_name } => ty_name.clone(),
113            Self::ArrayOfRefs { ty_name } => format!("{ty_name}s"),
114            Self::Primitive(PrimitiveType::I64) => "I64".to_owned(),
115            Self::Primitive(PrimitiveType::I32) => "I32".to_owned(),
116            Self::Primitive(PrimitiveType::Float) => "Float".to_owned(),
117            Self::Primitive(PrimitiveType::String) => "String".to_owned(),
118            Self::Primitive(PrimitiveType::DateTime) => "DateTime".to_owned(),
119            Self::Primitive(PrimitiveType::Bool) => "Bool".to_owned(),
120            Self::Enum { .. } => "Variant".to_owned(),
121        }
122    }
123
124    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
125        match self {
126            Self::Primitive(_) => true,
127            Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
128                .models
129                .get(ty_name)
130                .map(|f| f.is_display(resolved))
131                .unwrap_or_default(),
132            Self::Enum { inner, .. } => inner.is_display(resolved),
133        }
134    }
135
136    pub fn codegen_display(&self) -> TokenStream {
137        match self {
138            Self::ArrayOfRefs { .. } => quote! {
139                write!(f, "{}", value.iter().map(ToString::to_string).collect::<Vec<_>>().join(","))
140            },
141            _ => quote! {
142                write!(f, "{}", value)
143            },
144        }
145    }
146
147    pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
148        match self {
149            Self::Primitive(_) => true,
150            Self::Enum { inner, .. } => inner.is_comparable(resolved),
151            Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
152                .models
153                .get(ty_name)
154                .map(|m| matches!(m, Model::Newtype(_)))
155                .unwrap_or_default(),
156        }
157    }
158}
159
160#[derive(Debug, Clone, PartialEq, Eq)]
161pub enum EnumVariantValue {
162    Repr(u32),
163    String { rename: Option<String> },
164    Tuple(Vec<EnumVariantTupleValue>),
165}
166
167impl Default for EnumVariantValue {
168    fn default() -> Self {
169        Self::String { rename: None }
170    }
171}
172
173impl EnumVariantValue {
174    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
175        match self {
176            Self::Repr(_) | Self::String { .. } => true,
177            Self::Tuple(val) => {
178                val.len() == 1
179                    && val
180                        .iter()
181                        .next()
182                        .map(|v| v.is_display(resolved))
183                        .unwrap_or_default()
184            }
185        }
186    }
187
188    pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
189        match self {
190            Self::Repr(_) | Self::String { .. } => true,
191            Self::Tuple(values) => values.iter().all(|v| v.is_comparable(resolved)),
192        }
193    }
194
195    pub fn codegen_display(&self, name: &str) -> Option<TokenStream> {
196        let variant = format_ident!("{name}");
197
198        match self {
199            Self::Repr(i) => Some(quote! { Self::#variant => write!(f, "{}", #i) }),
200            Self::String { rename } => {
201                let name = rename.as_deref().unwrap_or(name);
202                Some(quote! { Self::#variant => write!(f, #name) })
203            }
204            Self::Tuple(values) if values.len() == 1 => {
205                let rhs = values.first().unwrap().codegen_display();
206                Some(quote! { Self::#variant(value) => #rhs })
207            }
208            _ => None,
209        }
210    }
211}
212
213#[derive(Debug, Clone, PartialEq, Eq, Default)]
214pub struct EnumVariant {
215    pub name: String,
216    pub description: Option<String>,
217    pub value: EnumVariantValue,
218}
219
220pub struct EnumNamespace<'e> {
221    r#enum: &'e Enum,
222    ident: Option<Ident>,
223    elements: Vec<TokenStream>,
224    top_level_elements: Vec<TokenStream>,
225}
226
227impl EnumNamespace<'_> {
228    pub fn get_ident(&mut self) -> Ident {
229        self.ident
230            .get_or_insert_with(|| {
231                let name = self.r#enum.name.to_snake_case();
232                format_ident!("{name}")
233            })
234            .clone()
235    }
236
237    pub fn push_element(&mut self, el: TokenStream) {
238        self.elements.push(el);
239    }
240
241    pub fn push_top_level(&mut self, el: TokenStream) {
242        self.top_level_elements.push(el);
243    }
244
245    pub fn codegen(mut self) -> Option<TokenStream> {
246        if self.elements.is_empty() && self.top_level_elements.is_empty() {
247            None
248        } else {
249            let top_level = &self.top_level_elements;
250            let mut output = quote! {
251                #(#top_level)*
252            };
253
254            if !self.elements.is_empty() {
255                let ident = self.get_ident();
256                let elements = self.elements;
257                output.extend(quote! {
258                    pub mod #ident {
259                        #(#elements)*
260                    }
261                });
262            }
263
264            Some(output)
265        }
266    }
267}
268
269impl EnumVariant {
270    pub fn codegen(
271        &self,
272        ns: &mut EnumNamespace,
273        resolved: &ResolvedSchema,
274    ) -> Option<TokenStream> {
275        let doc = self.description.as_ref().map(|d| {
276            quote! {
277                #[doc = #d]
278            }
279        });
280
281        let name = format_ident!("{}", self.name);
282
283        match &self.value {
284            EnumVariantValue::Repr(repr) => Some(quote! {
285                #doc
286                #name = #repr
287            }),
288            EnumVariantValue::String { rename } => {
289                let serde_attr = rename.as_ref().map(|r| {
290                    quote! {
291                        #[serde(rename = #r)]
292                    }
293                });
294
295                Some(quote! {
296                    #doc
297                    #serde_attr
298                    #name
299                })
300            }
301            EnumVariantValue::Tuple(values) => {
302                let mut val_tys = Vec::with_capacity(values.len());
303
304                if let [value] = values.as_slice() {
305                    let enum_name = format_ident!("{}", ns.r#enum.name);
306                    let ty_name = value.type_name(ns);
307
308                    ns.push_top_level(quote! {
309                        impl From<#ty_name> for #enum_name {
310                            fn from(value: #ty_name) -> Self {
311                                Self::#name(value)
312                            }
313                        }
314                    });
315                }
316
317                for value in values {
318                    let ty_name = value.type_name(ns);
319
320                    if let EnumVariantTupleValue::Enum { inner, .. } = &value {
321                        ns.push_element(inner.codegen(resolved)?);
322                    }
323
324                    val_tys.push(ty_name);
325                }
326
327                Some(quote! {
328                    #name(#(#val_tys),*)
329                })
330            }
331        }
332    }
333
334    pub fn codegen_display(&self) -> Option<TokenStream> {
335        self.value.codegen_display(&self.name)
336    }
337}
338
339#[derive(Debug, Clone, PartialEq, Eq, Default)]
340pub struct Enum {
341    pub name: String,
342    pub description: Option<String>,
343    pub repr: Option<EnumRepr>,
344    pub copy: bool,
345    pub untagged: bool,
346    pub variants: Vec<EnumVariant>,
347}
348
349impl Enum {
350    pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
351        let mut result = Enum {
352            name: name.to_owned(),
353            description: schema.description.as_deref().map(ToOwned::to_owned),
354            copy: true,
355            ..Default::default()
356        };
357
358        match &schema.r#enum {
359            Some(OpenApiVariants::Int(int_variants)) => {
360                result.repr = Some(EnumRepr::U32);
361                result.variants = int_variants
362                    .iter()
363                    .copied()
364                    .map(|i| EnumVariant {
365                        name: format!("Variant{i}"),
366                        value: EnumVariantValue::Repr(i as u32),
367                        ..Default::default()
368                    })
369                    .collect();
370            }
371            Some(OpenApiVariants::Str(str_variants)) => {
372                result.variants = str_variants
373                    .iter()
374                    .copied()
375                    .map(|s| {
376                        let transformed = s.replace('&', "And").to_upper_camel_case();
377                        EnumVariant {
378                            value: EnumVariantValue::String {
379                                rename: (transformed != s).then(|| s.to_owned()),
380                            },
381                            name: transformed,
382                            ..Default::default()
383                        }
384                    })
385                    .collect();
386            }
387            None => return None,
388        }
389
390        Some(result)
391    }
392
393    pub fn from_parameter_schema(name: &str, schema: &OpenApiParameterSchema) -> Option<Self> {
394        let mut result = Self {
395            name: name.to_owned(),
396            copy: true,
397            ..Default::default()
398        };
399
400        for var in schema.r#enum.as_ref()? {
401            let transformed = var.to_upper_camel_case();
402            result.variants.push(EnumVariant {
403                value: EnumVariantValue::String {
404                    rename: (transformed != *var).then(|| transformed.clone()),
405                },
406                name: transformed,
407                ..Default::default()
408            });
409        }
410
411        Some(result)
412    }
413
414    pub fn from_one_of(name: &str, schemas: &[OpenApiType]) -> Option<Self> {
415        let mut result = Self {
416            name: name.to_owned(),
417            untagged: true,
418            ..Default::default()
419        };
420
421        for schema in schemas {
422            let value = EnumVariantTupleValue::from_schema(name, schema)?;
423            let name = value.name();
424
425            result.variants.push(EnumVariant {
426                name,
427                value: EnumVariantValue::Tuple(vec![value]),
428                ..Default::default()
429            });
430        }
431
432        // HACK: idk
433        let shared: Vec<_> = result
434            .variants
435            .iter_mut()
436            .filter(|v| v.name == "Variant")
437            .collect();
438        if shared.len() >= 2 {
439            for (idx, variant) in shared.into_iter().enumerate() {
440                let label = idx + 1;
441                variant.name = format!("Variant{}", label);
442                if let EnumVariantValue::Tuple(values) = &mut variant.value {
443                    if let [EnumVariantTupleValue::Enum { name, inner, .. }] = values.as_mut_slice()
444                    {
445                        inner.name.push_str(&label.to_string());
446                        name.push_str(&label.to_string());
447                    }
448                }
449            }
450        }
451        Some(result)
452    }
453
454    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
455        self.variants.iter().all(|v| v.value.is_display(resolved))
456    }
457
458    pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
459        self.variants
460            .iter()
461            .all(|v| v.value.is_comparable(resolved))
462    }
463
464    pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
465        let repr = self.repr.map(|r| match r {
466            EnumRepr::U8 => quote! { #[repr(u8)] },
467            EnumRepr::U32 => quote! { #[repr(u32)] },
468        });
469        let name = format_ident!("{}", self.name);
470        let desc = self.description.as_ref().map(|d| {
471            quote! {
472                #repr
473                #[doc = #d]
474            }
475        });
476
477        let mut ns = EnumNamespace {
478            r#enum: self,
479            ident: None,
480            elements: Default::default(),
481            top_level_elements: Default::default(),
482        };
483
484        let is_display = self.is_display(resolved);
485
486        let mut display = Vec::with_capacity(self.variants.len());
487        let mut variants = Vec::with_capacity(self.variants.len());
488        for variant in &self.variants {
489            variants.push(variant.codegen(&mut ns, resolved)?);
490
491            if is_display {
492                display.push(variant.codegen_display()?);
493            }
494        }
495
496        let mut derives = vec![];
497
498        if self.repr.is_some() {
499            derives.push(quote! { serde_repr::Deserialize_repr });
500        } else {
501            derives.push(quote! { serde::Deserialize });
502        }
503
504        if self.copy {
505            derives.push(quote! { Copy });
506        }
507
508        if self.is_comparable(resolved) {
509            derives.push(quote! { Eq, Hash });
510        }
511
512        let serde_attr = self.untagged.then(|| {
513            quote! {
514                #[serde(untagged)]
515            }
516        });
517
518        let display = is_display.then(|| {
519            quote! {
520                impl std::fmt::Display for #name {
521                    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
522                        match self {
523                            #(#display),*
524                        }
525                    }
526                }
527            }
528        });
529
530        let module = ns.codegen();
531
532        Some(quote! {
533            #desc
534            #[derive(Debug, Clone, PartialEq, #(#derives),*)]
535            #[cfg_attr(feature = "strum", derive(strum::EnumIs, strum::EnumTryAs))]
536            #serde_attr
537            pub enum #name {
538                #(#variants),*
539            }
540            #display
541
542            #module
543        })
544    }
545}
546
#[cfg(test)]
mod test {
    use super::*;

    use crate::openapi::schema::test::get_schema;

    /// The `TornSelectionName` model of the test schema must be displayable.
    #[test]
    fn is_display() {
        let open_api = get_schema();
        let resolved = ResolvedSchema::from_open_api(&open_api);

        let model = resolved.models.get("TornSelectionName").unwrap();
        let displayable = model.is_display(&resolved);
        assert!(displayable);
    }
}