1use std::collections::HashSet;
2
3use heck::{ToSnakeCase, ToUpperCamelCase};
4use proc_macro2::TokenStream;
5use quote::{format_ident, quote};
6use syn::Ident;
7
8use crate::{
9 model::WarningReporter,
10 openapi::{
11 parameter::OpenApiParameterSchema,
12 r#type::{OpenApiType, OpenApiVariants},
13 },
14};
15
16use super::{object::PrimitiveType, Model, ResolvedSchema};
17
/// Integer representation attached as `#[repr(..)]` to a generated enum.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum EnumRepr {
    /// Rendered as `#[repr(u8)]`.
    U8,
    /// Rendered as `#[repr(u32)]`.
    U32,
}
23
24#[derive(Debug, Clone, PartialEq, Eq)]
25pub enum EnumVariantTupleValue {
26 Ref { ty_name: String },
27 ArrayOfRefs { ty_name: String },
28 Primitive(PrimitiveType),
29 Enum { name: String, inner: Enum },
30}
31
32impl EnumVariantTupleValue {
33 pub fn from_schema(
34 name: &str,
35 schema: &OpenApiType,
36 warnings: WarningReporter,
37 ) -> Option<Self> {
38 match schema {
39 OpenApiType {
40 ref_path: Some(path),
41 ..
42 } => Some(Self::Ref {
43 ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
44 }),
45 OpenApiType {
46 r#type: Some("array"),
47 items: Some(items),
48 ..
49 } => {
50 let OpenApiType {
51 ref_path: Some(path),
52 ..
53 } = items.as_ref()
54 else {
55 return None;
56 };
57 Some(Self::ArrayOfRefs {
58 ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
59 })
60 }
61 OpenApiType {
62 r#type: Some("string"),
63 format: None,
64 r#enum: None,
65 ..
66 } => Some(Self::Primitive(PrimitiveType::String)),
67 OpenApiType {
68 r#type: Some("string"),
69 format: None,
70 r#enum: Some(_),
71 ..
72 } => {
73 let name = format!("{name}Variant");
74 Some(Self::Enum {
75 inner: Enum::from_schema(&name, schema, warnings.child(name.clone()))?,
76 name,
77 })
78 }
79 OpenApiType {
80 r#type: Some("integer"),
81 format: Some("int64"),
82 ..
83 } => Some(Self::Primitive(PrimitiveType::I64)),
84 OpenApiType {
85 r#type: Some("integer"),
86 format: Some("int32"),
87 ..
88 } => Some(Self::Primitive(PrimitiveType::I32)),
89 OpenApiType {
90 r#type: Some("number"),
91 format: Some("float") | None,
92 ..
93 } => Some(Self::Primitive(PrimitiveType::Float)),
94 _ => None,
95 }
96 }
97
98 pub fn type_name(&self, ns: &mut EnumNamespace) -> TokenStream {
99 match self {
100 Self::Ref { ty_name } => {
101 let ty = format_ident!("{ty_name}");
102 quote! { crate::models::#ty }
103 }
104 Self::ArrayOfRefs { ty_name } => {
105 let ty = format_ident!("{ty_name}");
106 quote! { Vec<crate::models::#ty> }
107 }
108 Self::Primitive(PrimitiveType::I64) => quote! { i64 },
109 Self::Primitive(PrimitiveType::I32) => quote! { i32 },
110 Self::Primitive(PrimitiveType::Float) => quote! { f32 },
111 Self::Primitive(PrimitiveType::String) => quote! { String },
112 Self::Primitive(PrimitiveType::DateTime) => quote! { chrono::DateTime<chrono::Utc> },
113 Self::Primitive(PrimitiveType::Bool) => quote! { bool },
114 Self::Enum { name, .. } => {
115 let path = ns.get_ident();
116 let ty_name = format_ident!("{name}");
117 quote! {
118 #path::#ty_name,
119 }
120 }
121 }
122 }
123
124 pub fn name(&self) -> String {
125 match self {
126 Self::Ref { ty_name } => ty_name.clone(),
127 Self::ArrayOfRefs { ty_name } => format!("{ty_name}s"),
128 Self::Primitive(PrimitiveType::I64) => "I64".to_owned(),
129 Self::Primitive(PrimitiveType::I32) => "I32".to_owned(),
130 Self::Primitive(PrimitiveType::Float) => "Float".to_owned(),
131 Self::Primitive(PrimitiveType::String) => "String".to_owned(),
132 Self::Primitive(PrimitiveType::DateTime) => "DateTime".to_owned(),
133 Self::Primitive(PrimitiveType::Bool) => "Bool".to_owned(),
134 Self::Enum { .. } => "Variant".to_owned(),
135 }
136 }
137
138 pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
139 match self {
140 Self::Primitive(_) => true,
141 Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
142 .models
143 .get(ty_name)
144 .map(|f| f.is_display(resolved))
145 .unwrap_or_default(),
146 Self::Enum { inner, .. } => inner.is_display(resolved),
147 }
148 }
149
150 pub fn codegen_display(&self) -> TokenStream {
151 match self {
152 Self::ArrayOfRefs { .. } => quote! {
153 write!(f, "{}", value.iter().map(ToString::to_string).collect::<Vec<_>>().join(","))
154 },
155 _ => quote! {
156 write!(f, "{}", value)
157 },
158 }
159 }
160
161 pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
162 match self {
163 Self::Primitive(PrimitiveType::Float) => false,
164 Self::Primitive(_) => true,
165 Self::Enum { inner, .. } => inner.is_comparable(resolved),
166 Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
167 .models
168 .get(ty_name)
169 .map(|m| matches!(m, Model::Newtype(_)))
170 .unwrap_or_default(),
171 }
172 }
173}
174
175#[derive(Debug, Clone, PartialEq, Eq)]
176pub enum EnumVariantValue {
177 Repr(u32),
178 String { rename: Option<String> },
179 Tuple(Vec<EnumVariantTupleValue>),
180}
181
182impl Default for EnumVariantValue {
183 fn default() -> Self {
184 Self::String { rename: None }
185 }
186}
187
188impl EnumVariantValue {
189 pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
190 match self {
191 Self::Repr(_) | Self::String { .. } => true,
192 Self::Tuple(val) => {
193 val.len() == 1
194 && val
195 .iter()
196 .next()
197 .map(|v| v.is_display(resolved))
198 .unwrap_or_default()
199 }
200 }
201 }
202
203 pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
204 match self {
205 Self::Repr(_) | Self::String { .. } => true,
206 Self::Tuple(values) => values.iter().all(|v| v.is_comparable(resolved)),
207 }
208 }
209
210 pub fn codegen_display(&self, name: &str) -> Option<TokenStream> {
211 let variant = format_ident!("{name}");
212
213 match self {
214 Self::Repr(i) => Some(quote! { Self::#variant => write!(f, "{}", #i) }),
215 Self::String { rename } => {
216 let name = rename.as_deref().unwrap_or(name);
217 Some(quote! { Self::#variant => write!(f, #name) })
218 }
219 Self::Tuple(values) if values.len() == 1 => {
220 let rhs = values.first().unwrap().codegen_display();
221 Some(quote! { Self::#variant(value) => #rhs })
222 }
223 _ => None,
224 }
225 }
226}
227
228#[derive(Debug, Clone, PartialEq, Eq, Default)]
229pub struct EnumVariant {
230 pub name: String,
231 pub description: Option<String>,
232 pub value: EnumVariantValue,
233}
234
235pub struct EnumNamespace<'e> {
236 r#enum: &'e Enum,
237 ident: Option<Ident>,
238 elements: Vec<TokenStream>,
239 top_level_elements: Vec<TokenStream>,
240}
241
242impl EnumNamespace<'_> {
243 pub fn get_ident(&mut self) -> Ident {
244 self.ident
245 .get_or_insert_with(|| {
246 let name = self.r#enum.name.to_snake_case();
247 format_ident!("{name}")
248 })
249 .clone()
250 }
251
252 pub fn push_element(&mut self, el: TokenStream) {
253 self.elements.push(el);
254 }
255
256 pub fn push_top_level(&mut self, el: TokenStream) {
257 self.top_level_elements.push(el);
258 }
259
260 pub fn codegen(mut self) -> Option<TokenStream> {
261 if self.elements.is_empty() && self.top_level_elements.is_empty() {
262 None
263 } else {
264 let top_level = &self.top_level_elements;
265 let mut output = quote! {
266 #(#top_level)*
267 };
268
269 if !self.elements.is_empty() {
270 let ident = self.get_ident();
271 let elements = self.elements;
272 output.extend(quote! {
273 pub mod #ident {
274 #(#elements)*
275 }
276 });
277 }
278
279 Some(output)
280 }
281 }
282}
283
284impl EnumVariant {
285 pub fn codegen(
286 &self,
287 ns: &mut EnumNamespace,
288 resolved: &ResolvedSchema,
289 ) -> Option<TokenStream> {
290 let doc = self.description.as_ref().map(|d| {
291 quote! {
292 #[doc = #d]
293 }
294 });
295
296 let name = format_ident!("{}", self.name);
297
298 match &self.value {
299 EnumVariantValue::Repr(repr) => Some(quote! {
300 #doc
301 #name = #repr
302 }),
303 EnumVariantValue::String { rename } => {
304 let serde_attr = rename.as_ref().map(|r| {
305 quote! {
306 #[serde(rename = #r)]
307 }
308 });
309
310 Some(quote! {
311 #doc
312 #serde_attr
313 #name
314 })
315 }
316 EnumVariantValue::Tuple(values) => {
317 let mut val_tys = Vec::with_capacity(values.len());
318
319 if let [value] = values.as_slice() {
320 let enum_name = format_ident!("{}", ns.r#enum.name);
321 let ty_name = value.type_name(ns);
322
323 ns.push_top_level(quote! {
324 impl From<#ty_name> for #enum_name {
325 fn from(value: #ty_name) -> Self {
326 Self::#name(value)
327 }
328 }
329 });
330 }
331
332 for value in values {
333 let ty_name = value.type_name(ns);
334
335 if let EnumVariantTupleValue::Enum { inner, .. } = &value {
336 ns.push_element(inner.codegen(resolved)?);
337 }
338
339 val_tys.push(ty_name);
340 }
341
342 Some(quote! {
343 #name(#(#val_tys),*)
344 })
345 }
346 }
347 }
348
349 pub fn codegen_display(&self) -> Option<TokenStream> {
350 self.value.codegen_display(&self.name)
351 }
352}
353
354#[derive(Debug, Clone, PartialEq, Eq, Default)]
355pub struct Enum {
356 pub name: String,
357 pub description: Option<String>,
358 pub repr: Option<EnumRepr>,
359 pub copy: bool,
360 pub untagged: bool,
361 pub variants: Vec<EnumVariant>,
362}
363
364impl Enum {
365 pub fn from_schema(
366 name: &str,
367 schema: &OpenApiType,
368 warnings: WarningReporter,
369 ) -> Option<Self> {
370 let mut result = Enum {
371 name: name.to_owned(),
372 description: schema.description.as_deref().map(ToOwned::to_owned),
373 copy: true,
374 ..Default::default()
375 };
376
377 match &schema.r#enum {
378 Some(OpenApiVariants::Int(int_variants)) => {
379 result.repr = Some(EnumRepr::U32);
380 result.variants = int_variants
381 .iter()
382 .copied()
383 .map(|i| EnumVariant {
384 name: format!("Variant{i}"),
385 value: EnumVariantValue::Repr(i as u32),
386 ..Default::default()
387 })
388 .collect();
389 }
390 Some(OpenApiVariants::Str(str_variants)) => {
391 result.variants = str_variants
392 .iter()
393 .copied()
394 .map(|s| {
395 let transformed = s.replace('&', "And").to_upper_camel_case();
396 EnumVariant {
397 value: EnumVariantValue::String {
398 rename: (transformed != s).then(|| s.to_owned()),
399 },
400 name: transformed,
401 ..Default::default()
402 }
403 })
404 .collect();
405
406 let mut visited = HashSet::with_capacity(result.variants.len());
407 let mut new_variants = Vec::with_capacity(result.variants.len());
408 for variant in result.variants {
409 if visited.contains(&variant.name) {
410 warnings.push(format!("Duplicate case `{}`", variant.name));
411 } else {
412 visited.insert(variant.name.clone());
413 new_variants.push(variant);
414 }
415 }
416
417 result.variants = new_variants;
418 }
419 None => return None,
420 }
421
422 Some(result)
423 }
424
425 pub fn from_parameter_schema(name: &str, schema: &OpenApiParameterSchema) -> Option<Self> {
426 let mut result = Self {
427 name: name.to_owned(),
428 copy: true,
429 ..Default::default()
430 };
431
432 for var in schema.r#enum.as_ref()? {
433 let transformed = var.to_upper_camel_case();
434 result.variants.push(EnumVariant {
435 value: EnumVariantValue::String {
436 rename: (transformed != *var).then(|| transformed.clone()),
437 },
438 name: transformed,
439 ..Default::default()
440 });
441 }
442
443 Some(result)
444 }
445
446 pub fn from_one_of(
447 name: &str,
448 schemas: &[OpenApiType],
449 warnings: WarningReporter,
450 ) -> Option<Self> {
451 let mut result = Self {
452 name: name.to_owned(),
453 untagged: true,
454 ..Default::default()
455 };
456
457 for schema in schemas {
458 let value =
459 EnumVariantTupleValue::from_schema(name, schema, warnings.child(name.to_owned()))?;
460 let name = value.name();
461
462 result.variants.push(EnumVariant {
463 name,
464 value: EnumVariantValue::Tuple(vec![value]),
465 ..Default::default()
466 });
467 }
468
469 let shared: Vec<_> = result
471 .variants
472 .iter_mut()
473 .filter(|v| v.name == "Variant")
474 .collect();
475 if shared.len() >= 2 {
476 for (idx, variant) in shared.into_iter().enumerate() {
477 let label = idx + 1;
478 variant.name = format!("Variant{label}");
479 if let EnumVariantValue::Tuple(values) = &mut variant.value {
480 if let [EnumVariantTupleValue::Enum { name, inner, .. }] = values.as_mut_slice()
481 {
482 inner.name.push_str(&label.to_string());
483 name.push_str(&label.to_string());
484 }
485 }
486 }
487 }
488 Some(result)
489 }
490
491 pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
492 self.variants.iter().all(|v| v.value.is_display(resolved))
493 }
494
495 pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
496 self.variants
497 .iter()
498 .all(|v| v.value.is_comparable(resolved))
499 }
500
501 pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
502 let repr = self.repr.map(|r| match r {
503 EnumRepr::U8 => quote! { #[repr(u8)] },
504 EnumRepr::U32 => quote! { #[repr(u32)] },
505 });
506 let name = format_ident!("{}", self.name);
507 let desc = self.description.as_ref().map(|d| {
508 quote! {
509 #[doc = #d]
510 }
511 });
512
513 let mut ns = EnumNamespace {
514 r#enum: self,
515 ident: None,
516 elements: Default::default(),
517 top_level_elements: Default::default(),
518 };
519
520 let is_display = self.is_display(resolved);
521
522 let mut display = Vec::with_capacity(self.variants.len());
523 let mut variants = Vec::with_capacity(self.variants.len());
524 for variant in &self.variants {
525 variants.push(variant.codegen(&mut ns, resolved)?);
526
527 if is_display {
528 display.push(variant.codegen_display()?);
529 }
530 }
531
532 let mut derives = vec![];
533
534 if self.repr.is_some() {
535 derives.push(quote! { serde_repr::Deserialize_repr });
536 } else {
537 derives.push(quote! { serde::Deserialize });
538 }
539
540 if self.copy {
541 derives.push(quote! { Copy });
542 }
543
544 if self.is_comparable(resolved) {
545 derives.push(quote! { Eq, Hash });
546 }
547
548 let serde_attr = self.untagged.then(|| {
549 quote! {
550 #[serde(untagged)]
551 }
552 });
553
554 let display = is_display.then(|| {
555 quote! {
556 impl std::fmt::Display for #name {
557 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
558 match self {
559 #(#display),*
560 }
561 }
562 }
563 }
564 });
565
566 let module = ns.codegen();
567
568 Some(quote! {
569 #desc
570 #[derive(Debug, Clone, PartialEq, #(#derives),*)]
571 #[cfg_attr(feature = "strum", derive(strum::EnumIs, strum::EnumTryAs))]
572 #serde_attr
573 #repr
574 pub enum #name {
575 #(#variants),*
576 }
577 #display
578
579 #module
580 })
581 }
582}
583
#[cfg(test)]
mod test {
    use super::*;

    use crate::openapi::schema::test::get_schema;

    /// Checks that the `TornSelectionName` model resolved from the shared test schema
    /// reports `is_display() == true`.
    #[test]
    fn is_display() {
        let schema = get_schema();
        let resolved = ResolvedSchema::from_open_api(&schema);

        // Panics if the test schema no longer defines `TornSelectionName`.
        let torn_selection_name = resolved.models.get("TornSelectionName").unwrap();
        assert!(torn_selection_name.is_display(&resolved));
    }
}