torn_api_codegen/model/
enum.rs

1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use syn::Ident;
5
6use crate::openapi::{
7    parameter::OpenApiParameterSchema,
8    r#type::{OpenApiType, OpenApiVariants},
9};
10
11use super::{object::PrimitiveType, Model, ResolvedSchema};
12
/// Integer representation for generated integer-valued enums.
///
/// Rendered as the matching `#[repr(..)]` attribute by [`Enum::codegen`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EnumRepr {
    /// `#[repr(u8)]`
    U8,
    /// `#[repr(u32)]`
    U32,
}
18
19#[derive(Debug, Clone, PartialEq, Eq)]
20pub enum EnumVariantTupleValue {
21    Ref { ty_name: String },
22    ArrayOfRefs { ty_name: String },
23    Primitive(PrimitiveType),
24    Enum { name: String, inner: Enum },
25}
26
27impl EnumVariantTupleValue {
28    pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
29        match schema {
30            OpenApiType {
31                ref_path: Some(path),
32                ..
33            } => Some(Self::Ref {
34                ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
35            }),
36            OpenApiType {
37                r#type: Some("array"),
38                items: Some(items),
39                ..
40            } => {
41                let OpenApiType {
42                    ref_path: Some(path),
43                    ..
44                } = items.as_ref()
45                else {
46                    return None;
47                };
48                Some(Self::ArrayOfRefs {
49                    ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
50                })
51            }
52            OpenApiType {
53                r#type: Some("string"),
54                format: None,
55                r#enum: None,
56                ..
57            } => Some(Self::Primitive(PrimitiveType::String)),
58            OpenApiType {
59                r#type: Some("string"),
60                format: None,
61                r#enum: Some(_),
62                ..
63            } => {
64                let name = format!("{name}Variant");
65                Some(Self::Enum {
66                    inner: Enum::from_schema(&name, schema)?,
67                    name,
68                })
69            }
70            OpenApiType {
71                r#type: Some("integer"),
72                format: Some("int64"),
73                ..
74            } => Some(Self::Primitive(PrimitiveType::I64)),
75            OpenApiType {
76                r#type: Some("integer"),
77                format: Some("int32"),
78                ..
79            } => Some(Self::Primitive(PrimitiveType::I32)),
80            OpenApiType {
81                r#type: Some("number"),
82                format: Some("float") | None,
83                ..
84            } => Some(Self::Primitive(PrimitiveType::Float)),
85            _ => None,
86        }
87    }
88
89    pub fn type_name(&self, ns: &mut EnumNamespace) -> TokenStream {
90        match self {
91            Self::Ref { ty_name } => {
92                let ty = format_ident!("{ty_name}");
93                quote! { crate::models::#ty }
94            }
95            Self::ArrayOfRefs { ty_name } => {
96                let ty = format_ident!("{ty_name}");
97                quote! { Vec<crate::models::#ty> }
98            }
99            Self::Primitive(PrimitiveType::I64) => quote! { i64 },
100            Self::Primitive(PrimitiveType::I32) => quote! { i32 },
101            Self::Primitive(PrimitiveType::Float) => quote! { f32 },
102            Self::Primitive(PrimitiveType::String) => quote! { String },
103            Self::Primitive(PrimitiveType::DateTime) => quote! { chrono::DateTime<chrono::Utc> },
104            Self::Primitive(PrimitiveType::Bool) => quote! { bool },
105            Self::Enum { name, .. } => {
106                let path = ns.get_ident();
107                let ty_name = format_ident!("{name}");
108                quote! {
109                    #path::#ty_name,
110                }
111            }
112        }
113    }
114
115    pub fn name(&self) -> String {
116        match self {
117            Self::Ref { ty_name } => ty_name.clone(),
118            Self::ArrayOfRefs { ty_name } => format!("{ty_name}s"),
119            Self::Primitive(PrimitiveType::I64) => "I64".to_owned(),
120            Self::Primitive(PrimitiveType::I32) => "I32".to_owned(),
121            Self::Primitive(PrimitiveType::Float) => "Float".to_owned(),
122            Self::Primitive(PrimitiveType::String) => "String".to_owned(),
123            Self::Primitive(PrimitiveType::DateTime) => "DateTime".to_owned(),
124            Self::Primitive(PrimitiveType::Bool) => "Bool".to_owned(),
125            Self::Enum { .. } => "Variant".to_owned(),
126        }
127    }
128
129    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
130        match self {
131            Self::Primitive(_) => true,
132            Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
133                .models
134                .get(ty_name)
135                .map(|f| f.is_display(resolved))
136                .unwrap_or_default(),
137            Self::Enum { inner, .. } => inner.is_display(resolved),
138        }
139    }
140
141    pub fn codegen_display(&self) -> TokenStream {
142        match self {
143            Self::ArrayOfRefs { .. } => quote! {
144                write!(f, "{}", value.iter().map(ToString::to_string).collect::<Vec<_>>().join(","))
145            },
146            _ => quote! {
147                write!(f, "{}", value)
148            },
149        }
150    }
151
152    pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
153        match self {
154            Self::Primitive(PrimitiveType::Float) => false,
155            Self::Primitive(_) => true,
156            Self::Enum { inner, .. } => inner.is_comparable(resolved),
157            Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
158                .models
159                .get(ty_name)
160                .map(|m| matches!(m, Model::Newtype(_)))
161                .unwrap_or_default(),
162        }
163    }
164}
165
166#[derive(Debug, Clone, PartialEq, Eq)]
167pub enum EnumVariantValue {
168    Repr(u32),
169    String { rename: Option<String> },
170    Tuple(Vec<EnumVariantTupleValue>),
171}
172
173impl Default for EnumVariantValue {
174    fn default() -> Self {
175        Self::String { rename: None }
176    }
177}
178
179impl EnumVariantValue {
180    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
181        match self {
182            Self::Repr(_) | Self::String { .. } => true,
183            Self::Tuple(val) => {
184                val.len() == 1
185                    && val
186                        .iter()
187                        .next()
188                        .map(|v| v.is_display(resolved))
189                        .unwrap_or_default()
190            }
191        }
192    }
193
194    pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
195        match self {
196            Self::Repr(_) | Self::String { .. } => true,
197            Self::Tuple(values) => values.iter().all(|v| v.is_comparable(resolved)),
198        }
199    }
200
201    pub fn codegen_display(&self, name: &str) -> Option<TokenStream> {
202        let variant = format_ident!("{name}");
203
204        match self {
205            Self::Repr(i) => Some(quote! { Self::#variant => write!(f, "{}", #i) }),
206            Self::String { rename } => {
207                let name = rename.as_deref().unwrap_or(name);
208                Some(quote! { Self::#variant => write!(f, #name) })
209            }
210            Self::Tuple(values) if values.len() == 1 => {
211                let rhs = values.first().unwrap().codegen_display();
212                Some(quote! { Self::#variant(value) => #rhs })
213            }
214            _ => None,
215        }
216    }
217}
218
219#[derive(Debug, Clone, PartialEq, Eq, Default)]
220pub struct EnumVariant {
221    pub name: String,
222    pub description: Option<String>,
223    pub value: EnumVariantValue,
224}
225
226pub struct EnumNamespace<'e> {
227    r#enum: &'e Enum,
228    ident: Option<Ident>,
229    elements: Vec<TokenStream>,
230    top_level_elements: Vec<TokenStream>,
231}
232
233impl EnumNamespace<'_> {
234    pub fn get_ident(&mut self) -> Ident {
235        self.ident
236            .get_or_insert_with(|| {
237                let name = self.r#enum.name.to_snake_case();
238                format_ident!("{name}")
239            })
240            .clone()
241    }
242
243    pub fn push_element(&mut self, el: TokenStream) {
244        self.elements.push(el);
245    }
246
247    pub fn push_top_level(&mut self, el: TokenStream) {
248        self.top_level_elements.push(el);
249    }
250
251    pub fn codegen(mut self) -> Option<TokenStream> {
252        if self.elements.is_empty() && self.top_level_elements.is_empty() {
253            None
254        } else {
255            let top_level = &self.top_level_elements;
256            let mut output = quote! {
257                #(#top_level)*
258            };
259
260            if !self.elements.is_empty() {
261                let ident = self.get_ident();
262                let elements = self.elements;
263                output.extend(quote! {
264                    pub mod #ident {
265                        #(#elements)*
266                    }
267                });
268            }
269
270            Some(output)
271        }
272    }
273}
274
275impl EnumVariant {
276    pub fn codegen(
277        &self,
278        ns: &mut EnumNamespace,
279        resolved: &ResolvedSchema,
280    ) -> Option<TokenStream> {
281        let doc = self.description.as_ref().map(|d| {
282            quote! {
283                #[doc = #d]
284            }
285        });
286
287        let name = format_ident!("{}", self.name);
288
289        match &self.value {
290            EnumVariantValue::Repr(repr) => Some(quote! {
291                #doc
292                #name = #repr
293            }),
294            EnumVariantValue::String { rename } => {
295                let serde_attr = rename.as_ref().map(|r| {
296                    quote! {
297                        #[serde(rename = #r)]
298                    }
299                });
300
301                Some(quote! {
302                    #doc
303                    #serde_attr
304                    #name
305                })
306            }
307            EnumVariantValue::Tuple(values) => {
308                let mut val_tys = Vec::with_capacity(values.len());
309
310                if let [value] = values.as_slice() {
311                    let enum_name = format_ident!("{}", ns.r#enum.name);
312                    let ty_name = value.type_name(ns);
313
314                    ns.push_top_level(quote! {
315                        impl From<#ty_name> for #enum_name {
316                            fn from(value: #ty_name) -> Self {
317                                Self::#name(value)
318                            }
319                        }
320                    });
321                }
322
323                for value in values {
324                    let ty_name = value.type_name(ns);
325
326                    if let EnumVariantTupleValue::Enum { inner, .. } = &value {
327                        ns.push_element(inner.codegen(resolved)?);
328                    }
329
330                    val_tys.push(ty_name);
331                }
332
333                Some(quote! {
334                    #name(#(#val_tys),*)
335                })
336            }
337        }
338    }
339
340    pub fn codegen_display(&self) -> Option<TokenStream> {
341        self.value.codegen_display(&self.name)
342    }
343}
344
345#[derive(Debug, Clone, PartialEq, Eq, Default)]
346pub struct Enum {
347    pub name: String,
348    pub description: Option<String>,
349    pub repr: Option<EnumRepr>,
350    pub copy: bool,
351    pub untagged: bool,
352    pub variants: Vec<EnumVariant>,
353}
354
355impl Enum {
356    pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
357        let mut result = Enum {
358            name: name.to_owned(),
359            description: schema.description.as_deref().map(ToOwned::to_owned),
360            copy: true,
361            ..Default::default()
362        };
363
364        match &schema.r#enum {
365            Some(OpenApiVariants::Int(int_variants)) => {
366                result.repr = Some(EnumRepr::U32);
367                result.variants = int_variants
368                    .iter()
369                    .copied()
370                    .map(|i| EnumVariant {
371                        name: format!("Variant{i}"),
372                        value: EnumVariantValue::Repr(i as u32),
373                        ..Default::default()
374                    })
375                    .collect();
376            }
377            Some(OpenApiVariants::Str(str_variants)) => {
378                result.variants = str_variants
379                    .iter()
380                    .copied()
381                    .map(|s| {
382                        let transformed = s.replace('&', "And").to_upper_camel_case();
383                        EnumVariant {
384                            value: EnumVariantValue::String {
385                                rename: (transformed != s).then(|| s.to_owned()),
386                            },
387                            name: transformed,
388                            ..Default::default()
389                        }
390                    })
391                    .collect();
392            }
393            None => return None,
394        }
395
396        Some(result)
397    }
398
399    pub fn from_parameter_schema(name: &str, schema: &OpenApiParameterSchema) -> Option<Self> {
400        let mut result = Self {
401            name: name.to_owned(),
402            copy: true,
403            ..Default::default()
404        };
405
406        for var in schema.r#enum.as_ref()? {
407            let transformed = var.to_upper_camel_case();
408            result.variants.push(EnumVariant {
409                value: EnumVariantValue::String {
410                    rename: (transformed != *var).then(|| transformed.clone()),
411                },
412                name: transformed,
413                ..Default::default()
414            });
415        }
416
417        Some(result)
418    }
419
420    pub fn from_one_of(name: &str, schemas: &[OpenApiType]) -> Option<Self> {
421        let mut result = Self {
422            name: name.to_owned(),
423            untagged: true,
424            ..Default::default()
425        };
426
427        for schema in schemas {
428            let value = EnumVariantTupleValue::from_schema(name, schema)?;
429            let name = value.name();
430
431            result.variants.push(EnumVariant {
432                name,
433                value: EnumVariantValue::Tuple(vec![value]),
434                ..Default::default()
435            });
436        }
437
438        // HACK: idk
439        let shared: Vec<_> = result
440            .variants
441            .iter_mut()
442            .filter(|v| v.name == "Variant")
443            .collect();
444        if shared.len() >= 2 {
445            for (idx, variant) in shared.into_iter().enumerate() {
446                let label = idx + 1;
447                variant.name = format!("Variant{}", label);
448                if let EnumVariantValue::Tuple(values) = &mut variant.value {
449                    if let [EnumVariantTupleValue::Enum { name, inner, .. }] = values.as_mut_slice()
450                    {
451                        inner.name.push_str(&label.to_string());
452                        name.push_str(&label.to_string());
453                    }
454                }
455            }
456        }
457        Some(result)
458    }
459
460    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
461        self.variants.iter().all(|v| v.value.is_display(resolved))
462    }
463
464    pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
465        self.variants
466            .iter()
467            .all(|v| v.value.is_comparable(resolved))
468    }
469
470    pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
471        let repr = self.repr.map(|r| match r {
472            EnumRepr::U8 => quote! { #[repr(u8)] },
473            EnumRepr::U32 => quote! { #[repr(u32)] },
474        });
475        let name = format_ident!("{}", self.name);
476        let desc = self.description.as_ref().map(|d| {
477            quote! {
478                #repr
479                #[doc = #d]
480            }
481        });
482
483        let mut ns = EnumNamespace {
484            r#enum: self,
485            ident: None,
486            elements: Default::default(),
487            top_level_elements: Default::default(),
488        };
489
490        let is_display = self.is_display(resolved);
491
492        let mut display = Vec::with_capacity(self.variants.len());
493        let mut variants = Vec::with_capacity(self.variants.len());
494        for variant in &self.variants {
495            variants.push(variant.codegen(&mut ns, resolved)?);
496
497            if is_display {
498                display.push(variant.codegen_display()?);
499            }
500        }
501
502        let mut derives = vec![];
503
504        if self.repr.is_some() {
505            derives.push(quote! { serde_repr::Deserialize_repr });
506        } else {
507            derives.push(quote! { serde::Deserialize });
508        }
509
510        if self.copy {
511            derives.push(quote! { Copy });
512        }
513
514        if self.is_comparable(resolved) {
515            derives.push(quote! { Eq, Hash });
516        }
517
518        let serde_attr = self.untagged.then(|| {
519            quote! {
520                #[serde(untagged)]
521            }
522        });
523
524        let display = is_display.then(|| {
525            quote! {
526                impl std::fmt::Display for #name {
527                    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
528                        match self {
529                            #(#display),*
530                        }
531                    }
532                }
533            }
534        });
535
536        let module = ns.codegen();
537
538        Some(quote! {
539            #desc
540            #[derive(Debug, Clone, PartialEq, #(#derives),*)]
541            #[cfg_attr(feature = "strum", derive(strum::EnumIs, strum::EnumTryAs))]
542            #serde_attr
543            pub enum #name {
544                #(#variants),*
545            }
546            #display
547
548            #module
549        })
550    }
551}
552
#[cfg(test)]
mod test {
    use super::*;

    use crate::openapi::schema::test::get_schema;

    /// The `TornSelectionName` model from the shared test schema must support `Display`.
    #[test]
    fn is_display() {
        let schema = get_schema();
        let resolved = ResolvedSchema::from_open_api(&schema);

        let model = resolved.models.get("TornSelectionName");
        let displayable = model.unwrap().is_display(&resolved);
        assert!(displayable);
    }
}
565        assert!(torn_selection_name.is_display(&resolved));
566    }
567}