torn_api_codegen/model/
enum.rs

1use std::collections::HashSet;
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6use syn::Ident;
7
8use crate::{
9    model::WarningReporter,
10    openapi::{
11        parameter::OpenApiParameterSchema,
12        r#type::{OpenApiType, OpenApiVariants},
13    },
14};
15
16use super::{object::PrimitiveType, Model, ResolvedSchema};
17
/// Integer representation selected for a generated enum.
///
/// Determines which `#[repr(..)]` attribute `Enum::codegen` emits.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EnumRepr {
    /// Emits `#[repr(u8)]`.
    U8,
    /// Emits `#[repr(u32)]`.
    U32,
}
23
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub enum EnumVariantTupleValue {
26    Ref { ty_name: String },
27    ArrayOfRefs { ty_name: String },
28    Primitive(PrimitiveType),
29    Enum { name: String, inner: Enum },
30}
31
32impl EnumVariantTupleValue {
33    pub fn from_schema(
34        name: &str,
35        schema: &OpenApiType,
36        warnings: WarningReporter,
37    ) -> Option<Self> {
38        match schema {
39            OpenApiType {
40                ref_path: Some(path),
41                ..
42            } => Some(Self::Ref {
43                ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
44            }),
45            OpenApiType {
46                r#type: Some("array"),
47                items: Some(items),
48                ..
49            } => {
50                let OpenApiType {
51                    ref_path: Some(path),
52                    ..
53                } = items.as_ref()
54                else {
55                    return None;
56                };
57                Some(Self::ArrayOfRefs {
58                    ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
59                })
60            }
61            OpenApiType {
62                r#type: Some("string"),
63                format: None,
64                r#enum: None,
65                ..
66            } => Some(Self::Primitive(PrimitiveType::String)),
67            OpenApiType {
68                r#type: Some("string"),
69                format: None,
70                r#enum: Some(_),
71                ..
72            } => {
73                let name = format!("{name}Variant");
74                Some(Self::Enum {
75                    inner: Enum::from_schema(&name, schema, warnings.child(name.clone()))?,
76                    name,
77                })
78            }
79            OpenApiType {
80                r#type: Some("integer"),
81                format: Some("int64"),
82                ..
83            } => Some(Self::Primitive(PrimitiveType::I64)),
84            OpenApiType {
85                r#type: Some("integer"),
86                format: Some("int32"),
87                ..
88            } => Some(Self::Primitive(PrimitiveType::I32)),
89            OpenApiType {
90                r#type: Some("number"),
91                format: Some("float") | None,
92                ..
93            } => Some(Self::Primitive(PrimitiveType::Float)),
94            _ => None,
95        }
96    }
97
98    pub fn type_name(&self, ns: &mut EnumNamespace) -> TokenStream {
99        match self {
100            Self::Ref { ty_name } => {
101                let ty = format_ident!("{ty_name}");
102                quote! { crate::models::#ty }
103            }
104            Self::ArrayOfRefs { ty_name } => {
105                let ty = format_ident!("{ty_name}");
106                quote! { Vec<crate::models::#ty> }
107            }
108            Self::Primitive(PrimitiveType::I64) => quote! { i64 },
109            Self::Primitive(PrimitiveType::I32) => quote! { i32 },
110            Self::Primitive(PrimitiveType::Float) => quote! { f32 },
111            Self::Primitive(PrimitiveType::String) => quote! { String },
112            Self::Primitive(PrimitiveType::DateTime) => quote! { chrono::DateTime<chrono::Utc> },
113            Self::Primitive(PrimitiveType::Bool) => quote! { bool },
114            Self::Enum { name, .. } => {
115                let path = ns.get_ident();
116                let ty_name = format_ident!("{name}");
117                quote! {
118                    #path::#ty_name,
119                }
120            }
121        }
122    }
123
124    pub fn name(&self) -> String {
125        match self {
126            Self::Ref { ty_name } => ty_name.clone(),
127            Self::ArrayOfRefs { ty_name } => format!("{ty_name}s"),
128            Self::Primitive(PrimitiveType::I64) => "I64".to_owned(),
129            Self::Primitive(PrimitiveType::I32) => "I32".to_owned(),
130            Self::Primitive(PrimitiveType::Float) => "Float".to_owned(),
131            Self::Primitive(PrimitiveType::String) => "String".to_owned(),
132            Self::Primitive(PrimitiveType::DateTime) => "DateTime".to_owned(),
133            Self::Primitive(PrimitiveType::Bool) => "Bool".to_owned(),
134            Self::Enum { .. } => "Variant".to_owned(),
135        }
136    }
137
138    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
139        match self {
140            Self::Primitive(_) => true,
141            Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
142                .models
143                .get(ty_name)
144                .map(|f| f.is_display(resolved))
145                .unwrap_or_default(),
146            Self::Enum { inner, .. } => inner.is_display(resolved),
147        }
148    }
149
150    pub fn codegen_display(&self) -> TokenStream {
151        match self {
152            Self::ArrayOfRefs { .. } => quote! {
153                write!(f, "{}", value.iter().map(ToString::to_string).collect::<Vec<_>>().join(","))
154            },
155            _ => quote! {
156                write!(f, "{}", value)
157            },
158        }
159    }
160
161    pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
162        match self {
163            Self::Primitive(PrimitiveType::Float) => false,
164            Self::Primitive(_) => true,
165            Self::Enum { inner, .. } => inner.is_comparable(resolved),
166            Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
167                .models
168                .get(ty_name)
169                .map(|m| matches!(m, Model::Newtype(_)))
170                .unwrap_or_default(),
171        }
172    }
173}
174
175#[derive(Debug, Clone, PartialEq, Eq)]
176pub enum EnumVariantValue {
177    Repr(u32),
178    String { rename: Option<String> },
179    Tuple(Vec<EnumVariantTupleValue>),
180}
181
182impl Default for EnumVariantValue {
183    fn default() -> Self {
184        Self::String { rename: None }
185    }
186}
187
188impl EnumVariantValue {
189    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
190        match self {
191            Self::Repr(_) | Self::String { .. } => true,
192            Self::Tuple(val) => {
193                val.len() == 1
194                    && val
195                        .iter()
196                        .next()
197                        .map(|v| v.is_display(resolved))
198                        .unwrap_or_default()
199            }
200        }
201    }
202
203    pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
204        match self {
205            Self::Repr(_) | Self::String { .. } => true,
206            Self::Tuple(values) => values.iter().all(|v| v.is_comparable(resolved)),
207        }
208    }
209
210    pub fn codegen_display(&self, name: &str) -> Option<TokenStream> {
211        let variant = format_ident!("{name}");
212
213        match self {
214            Self::Repr(i) => Some(quote! { Self::#variant => write!(f, "{}", #i) }),
215            Self::String { rename } => {
216                let name = rename.as_deref().unwrap_or(name);
217                Some(quote! { Self::#variant => write!(f, #name) })
218            }
219            Self::Tuple(values) if values.len() == 1 => {
220                let rhs = values.first().unwrap().codegen_display();
221                Some(quote! { Self::#variant(value) => #rhs })
222            }
223            _ => None,
224        }
225    }
226}
227
228#[derive(Debug, Clone, PartialEq, Eq, Default)]
229pub struct EnumVariant {
230    pub name: String,
231    pub description: Option<String>,
232    pub value: EnumVariantValue,
233}
234
235pub struct EnumNamespace<'e> {
236    r#enum: &'e Enum,
237    ident: Option<Ident>,
238    elements: Vec<TokenStream>,
239    top_level_elements: Vec<TokenStream>,
240}
241
242impl EnumNamespace<'_> {
243    pub fn get_ident(&mut self) -> Ident {
244        self.ident
245            .get_or_insert_with(|| {
246                let name = self.r#enum.name.to_snake_case();
247                format_ident!("{name}")
248            })
249            .clone()
250    }
251
252    pub fn push_element(&mut self, el: TokenStream) {
253        self.elements.push(el);
254    }
255
256    pub fn push_top_level(&mut self, el: TokenStream) {
257        self.top_level_elements.push(el);
258    }
259
260    pub fn codegen(mut self) -> Option<TokenStream> {
261        if self.elements.is_empty() && self.top_level_elements.is_empty() {
262            None
263        } else {
264            let top_level = &self.top_level_elements;
265            let mut output = quote! {
266                #(#top_level)*
267            };
268
269            if !self.elements.is_empty() {
270                let ident = self.get_ident();
271                let elements = self.elements;
272                output.extend(quote! {
273                    pub mod #ident {
274                        #(#elements)*
275                    }
276                });
277            }
278
279            Some(output)
280        }
281    }
282}
283
284impl EnumVariant {
285    pub fn codegen(
286        &self,
287        ns: &mut EnumNamespace,
288        resolved: &ResolvedSchema,
289    ) -> Option<TokenStream> {
290        let doc = self.description.as_ref().map(|d| {
291            quote! {
292                #[doc = #d]
293            }
294        });
295
296        let name = format_ident!("{}", self.name);
297
298        match &self.value {
299            EnumVariantValue::Repr(repr) => Some(quote! {
300                #doc
301                #name = #repr
302            }),
303            EnumVariantValue::String { rename } => {
304                let serde_attr = rename.as_ref().map(|r| {
305                    quote! {
306                        #[serde(rename = #r)]
307                    }
308                });
309
310                Some(quote! {
311                    #doc
312                    #serde_attr
313                    #name
314                })
315            }
316            EnumVariantValue::Tuple(values) => {
317                let mut val_tys = Vec::with_capacity(values.len());
318
319                if let [value] = values.as_slice() {
320                    let enum_name = format_ident!("{}", ns.r#enum.name);
321                    let ty_name = value.type_name(ns);
322
323                    ns.push_top_level(quote! {
324                        impl From<#ty_name> for #enum_name {
325                            fn from(value: #ty_name) -> Self {
326                                Self::#name(value)
327                            }
328                        }
329                    });
330                }
331
332                for value in values {
333                    let ty_name = value.type_name(ns);
334
335                    if let EnumVariantTupleValue::Enum { inner, .. } = &value {
336                        ns.push_element(inner.codegen(resolved)?);
337                    }
338
339                    val_tys.push(ty_name);
340                }
341
342                Some(quote! {
343                    #name(#(#val_tys),*)
344                })
345            }
346        }
347    }
348
349    pub fn codegen_display(&self) -> Option<TokenStream> {
350        self.value.codegen_display(&self.name)
351    }
352}
353
354#[derive(Debug, Clone, PartialEq, Eq, Default)]
355pub struct Enum {
356    pub name: String,
357    pub description: Option<String>,
358    pub repr: Option<EnumRepr>,
359    pub copy: bool,
360    pub untagged: bool,
361    pub variants: Vec<EnumVariant>,
362}
363
364impl Enum {
365    pub fn from_schema(
366        name: &str,
367        schema: &OpenApiType,
368        warnings: WarningReporter,
369    ) -> Option<Self> {
370        let mut result = Enum {
371            name: name.to_owned(),
372            description: schema.description.as_deref().map(ToOwned::to_owned),
373            copy: true,
374            ..Default::default()
375        };
376
377        match &schema.r#enum {
378            Some(OpenApiVariants::Int(int_variants)) => {
379                result.repr = Some(EnumRepr::U32);
380                result.variants = int_variants
381                    .iter()
382                    .copied()
383                    .map(|i| EnumVariant {
384                        name: format!("Variant{i}"),
385                        value: EnumVariantValue::Repr(i as u32),
386                        ..Default::default()
387                    })
388                    .collect();
389            }
390            Some(OpenApiVariants::Str(str_variants)) => {
391                result.variants = str_variants
392                    .iter()
393                    .copied()
394                    .map(|s| {
395                        let transformed = s.replace('&', "And").to_upper_camel_case();
396                        EnumVariant {
397                            value: EnumVariantValue::String {
398                                rename: (transformed != s).then(|| s.to_owned()),
399                            },
400                            name: transformed,
401                            ..Default::default()
402                        }
403                    })
404                    .collect();
405
406                let mut visited = HashSet::with_capacity(result.variants.len());
407                let mut new_variants = Vec::with_capacity(result.variants.len());
408                for variant in result.variants {
409                    if visited.contains(&variant.name) {
410                        warnings.push(format!("Duplicate case `{}`", variant.name));
411                    } else {
412                        visited.insert(variant.name.clone());
413                        new_variants.push(variant);
414                    }
415                }
416
417                result.variants = new_variants;
418            }
419            None => return None,
420        }
421
422        Some(result)
423    }
424
425    pub fn from_parameter_schema(name: &str, schema: &OpenApiParameterSchema) -> Option<Self> {
426        let mut result = Self {
427            name: name.to_owned(),
428            copy: true,
429            ..Default::default()
430        };
431
432        for var in schema.r#enum.as_ref()? {
433            let transformed = var.to_upper_camel_case();
434            result.variants.push(EnumVariant {
435                value: EnumVariantValue::String {
436                    rename: (transformed != *var).then(|| transformed.clone()),
437                },
438                name: transformed,
439                ..Default::default()
440            });
441        }
442
443        Some(result)
444    }
445
446    pub fn from_one_of(
447        name: &str,
448        schemas: &[OpenApiType],
449        warnings: WarningReporter,
450    ) -> Option<Self> {
451        let mut result = Self {
452            name: name.to_owned(),
453            untagged: true,
454            ..Default::default()
455        };
456
457        for schema in schemas {
458            let value =
459                EnumVariantTupleValue::from_schema(name, schema, warnings.child(name.to_owned()))?;
460            let name = value.name();
461
462            result.variants.push(EnumVariant {
463                name,
464                value: EnumVariantValue::Tuple(vec![value]),
465                ..Default::default()
466            });
467        }
468
469        // HACK: idk
470        let shared: Vec<_> = result
471            .variants
472            .iter_mut()
473            .filter(|v| v.name == "Variant")
474            .collect();
475        if shared.len() >= 2 {
476            for (idx, variant) in shared.into_iter().enumerate() {
477                let label = idx + 1;
478                variant.name = format!("Variant{label}");
479                if let EnumVariantValue::Tuple(values) = &mut variant.value {
480                    if let [EnumVariantTupleValue::Enum { name, inner, .. }] = values.as_mut_slice()
481                    {
482                        inner.name.push_str(&label.to_string());
483                        name.push_str(&label.to_string());
484                    }
485                }
486            }
487        }
488        Some(result)
489    }
490
491    pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
492        self.variants.iter().all(|v| v.value.is_display(resolved))
493    }
494
495    pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
496        self.variants
497            .iter()
498            .all(|v| v.value.is_comparable(resolved))
499    }
500
501    pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
502        let repr = self.repr.map(|r| match r {
503            EnumRepr::U8 => quote! { #[repr(u8)] },
504            EnumRepr::U32 => quote! { #[repr(u32)] },
505        });
506        let name = format_ident!("{}", self.name);
507        let desc = self.description.as_ref().map(|d| {
508            quote! {
509                #[doc = #d]
510            }
511        });
512
513        let mut ns = EnumNamespace {
514            r#enum: self,
515            ident: None,
516            elements: Default::default(),
517            top_level_elements: Default::default(),
518        };
519
520        let is_display = self.is_display(resolved);
521
522        let mut display = Vec::with_capacity(self.variants.len());
523        let mut variants = Vec::with_capacity(self.variants.len());
524        for variant in &self.variants {
525            variants.push(variant.codegen(&mut ns, resolved)?);
526
527            if is_display {
528                display.push(variant.codegen_display()?);
529            }
530        }
531
532        let mut derives = vec![];
533
534        if self.repr.is_some() {
535            derives.push(quote! { serde_repr::Deserialize_repr });
536        } else {
537            derives.push(quote! { serde::Deserialize });
538        }
539
540        if self.copy {
541            derives.push(quote! { Copy });
542        }
543
544        if self.is_comparable(resolved) {
545            derives.push(quote! { Eq, Hash });
546        }
547
548        let serde_attr = self.untagged.then(|| {
549            quote! {
550                #[serde(untagged)]
551            }
552        });
553
554        let display = is_display.then(|| {
555            quote! {
556                impl std::fmt::Display for #name {
557                    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
558                        match self {
559                            #(#display),*
560                        }
561                    }
562                }
563            }
564        });
565
566        let module = ns.codegen();
567
568        Some(quote! {
569            #desc
570            #[derive(Debug, Clone, PartialEq, #(#derives),*)]
571            #[cfg_attr(feature = "strum", derive(strum::EnumIs, strum::EnumTryAs))]
572            #serde_attr
573            #repr
574            pub enum #name {
575                #(#variants),*
576            }
577            #display
578
579            #module
580        })
581    }
582}
583
#[cfg(test)]
mod test {
    use super::*;

    use crate::openapi::schema::test::get_schema;

    /// The `TornSelectionName` model resolved from the test schema must be displayable.
    #[test]
    fn is_display() {
        let open_api = get_schema();
        let resolved_schema = ResolvedSchema::from_open_api(&open_api);

        let selection_name = resolved_schema
            .models
            .get("TornSelectionName")
            .unwrap();

        assert!(selection_name.is_display(&resolved_schema));
    }
}