1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use syn::Ident;
5
6use crate::openapi::{
7 parameter::OpenApiParameterSchema,
8 r#type::{OpenApiType, OpenApiVariants},
9};
10
11use super::{object::PrimitiveType, Model, ResolvedSchema};
12
/// Integer representation requested for a generated enum (`#[repr(..)]`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EnumRepr {
    /// `#[repr(u8)]`
    U8,
    /// `#[repr(u32)]`
    U32,
}
18
19#[derive(Debug, Clone, PartialEq, Eq)]
20pub enum EnumVariantTupleValue {
21 Ref { ty_name: String },
22 ArrayOfRefs { ty_name: String },
23 Primitive(PrimitiveType),
24 Enum { name: String, inner: Enum },
25}
26
27impl EnumVariantTupleValue {
28 pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
29 match schema {
30 OpenApiType {
31 ref_path: Some(path),
32 ..
33 } => Some(Self::Ref {
34 ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
35 }),
36 OpenApiType {
37 r#type: Some("array"),
38 items: Some(items),
39 ..
40 } => {
41 let OpenApiType {
42 ref_path: Some(path),
43 ..
44 } = items.as_ref()
45 else {
46 return None;
47 };
48 Some(Self::ArrayOfRefs {
49 ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
50 })
51 }
52 OpenApiType {
53 r#type: Some("string"),
54 format: None,
55 r#enum: None,
56 ..
57 } => Some(Self::Primitive(PrimitiveType::String)),
58 OpenApiType {
59 r#type: Some("string"),
60 format: None,
61 r#enum: Some(_),
62 ..
63 } => {
64 let name = format!("{name}Variant");
65 Some(Self::Enum {
66 inner: Enum::from_schema(&name, schema)?,
67 name,
68 })
69 }
70 OpenApiType {
71 r#type: Some("integer"),
72 format: Some("int64"),
73 ..
74 } => Some(Self::Primitive(PrimitiveType::I64)),
75 OpenApiType {
76 r#type: Some("integer"),
77 format: Some("int32"),
78 ..
79 } => Some(Self::Primitive(PrimitiveType::I32)),
80 _ => None,
81 }
82 }
83
84 pub fn type_name(&self, ns: &mut EnumNamespace) -> TokenStream {
85 match self {
86 Self::Ref { ty_name } => {
87 let ty = format_ident!("{ty_name}");
88 quote! { crate::models::#ty }
89 }
90 Self::ArrayOfRefs { ty_name } => {
91 let ty = format_ident!("{ty_name}");
92 quote! { Vec<crate::models::#ty> }
93 }
94 Self::Primitive(PrimitiveType::I64) => quote! { i64 },
95 Self::Primitive(PrimitiveType::I32) => quote! { i32 },
96 Self::Primitive(PrimitiveType::Float) => quote! { f32 },
97 Self::Primitive(PrimitiveType::String) => quote! { String },
98 Self::Primitive(PrimitiveType::DateTime) => quote! { chrono::DateTime<chrono::Utc> },
99 Self::Primitive(PrimitiveType::Bool) => quote! { bool },
100 Self::Enum { name, .. } => {
101 let path = ns.get_ident();
102 let ty_name = format_ident!("{name}");
103 quote! {
104 #path::#ty_name,
105 }
106 }
107 }
108 }
109
110 pub fn name(&self) -> String {
111 match self {
112 Self::Ref { ty_name } => ty_name.clone(),
113 Self::ArrayOfRefs { ty_name } => format!("{ty_name}s"),
114 Self::Primitive(PrimitiveType::I64) => "I64".to_owned(),
115 Self::Primitive(PrimitiveType::I32) => "I32".to_owned(),
116 Self::Primitive(PrimitiveType::Float) => "Float".to_owned(),
117 Self::Primitive(PrimitiveType::String) => "String".to_owned(),
118 Self::Primitive(PrimitiveType::DateTime) => "DateTime".to_owned(),
119 Self::Primitive(PrimitiveType::Bool) => "Bool".to_owned(),
120 Self::Enum { .. } => "Variant".to_owned(),
121 }
122 }
123
124 pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
125 match self {
126 Self::Primitive(_) => true,
127 Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
128 .models
129 .get(ty_name)
130 .map(|f| f.is_display(resolved))
131 .unwrap_or_default(),
132 Self::Enum { inner, .. } => inner.is_display(resolved),
133 }
134 }
135
136 pub fn codegen_display(&self) -> TokenStream {
137 match self {
138 Self::ArrayOfRefs { .. } => quote! {
139 write!(f, "{}", value.iter().map(ToString::to_string).collect::<Vec<_>>().join(","))
140 },
141 _ => quote! {
142 write!(f, "{}", value)
143 },
144 }
145 }
146
147 pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
148 match self {
149 Self::Primitive(_) => true,
150 Self::Enum { inner, .. } => inner.is_comparable(resolved),
151 Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
152 .models
153 .get(ty_name)
154 .map(|m| matches!(m, Model::Newtype(_)))
155 .unwrap_or_default(),
156 }
157 }
158}
159
160#[derive(Debug, Clone, PartialEq, Eq)]
161pub enum EnumVariantValue {
162 Repr(u32),
163 String { rename: Option<String> },
164 Tuple(Vec<EnumVariantTupleValue>),
165}
166
167impl Default for EnumVariantValue {
168 fn default() -> Self {
169 Self::String { rename: None }
170 }
171}
172
173impl EnumVariantValue {
174 pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
175 match self {
176 Self::Repr(_) | Self::String { .. } => true,
177 Self::Tuple(val) => {
178 val.len() == 1
179 && val
180 .iter()
181 .next()
182 .map(|v| v.is_display(resolved))
183 .unwrap_or_default()
184 }
185 }
186 }
187
188 pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
189 match self {
190 Self::Repr(_) | Self::String { .. } => true,
191 Self::Tuple(values) => values.iter().all(|v| v.is_comparable(resolved)),
192 }
193 }
194
195 pub fn codegen_display(&self, name: &str) -> Option<TokenStream> {
196 let variant = format_ident!("{name}");
197
198 match self {
199 Self::Repr(i) => Some(quote! { Self::#variant => write!(f, "{}", #i) }),
200 Self::String { rename } => {
201 let name = rename.as_deref().unwrap_or(name);
202 Some(quote! { Self::#variant => write!(f, #name) })
203 }
204 Self::Tuple(values) if values.len() == 1 => {
205 let rhs = values.first().unwrap().codegen_display();
206 Some(quote! { Self::#variant(value) => #rhs })
207 }
208 _ => None,
209 }
210 }
211}
212
213#[derive(Debug, Clone, PartialEq, Eq, Default)]
214pub struct EnumVariant {
215 pub name: String,
216 pub description: Option<String>,
217 pub value: EnumVariantValue,
218}
219
220pub struct EnumNamespace<'e> {
221 r#enum: &'e Enum,
222 ident: Option<Ident>,
223 elements: Vec<TokenStream>,
224 top_level_elements: Vec<TokenStream>,
225}
226
227impl EnumNamespace<'_> {
228 pub fn get_ident(&mut self) -> Ident {
229 self.ident
230 .get_or_insert_with(|| {
231 let name = self.r#enum.name.to_snake_case();
232 format_ident!("{name}")
233 })
234 .clone()
235 }
236
237 pub fn push_element(&mut self, el: TokenStream) {
238 self.elements.push(el);
239 }
240
241 pub fn push_top_level(&mut self, el: TokenStream) {
242 self.top_level_elements.push(el);
243 }
244
245 pub fn codegen(mut self) -> Option<TokenStream> {
246 if self.elements.is_empty() && self.top_level_elements.is_empty() {
247 None
248 } else {
249 let top_level = &self.top_level_elements;
250 let mut output = quote! {
251 #(#top_level)*
252 };
253
254 if !self.elements.is_empty() {
255 let ident = self.get_ident();
256 let elements = self.elements;
257 output.extend(quote! {
258 pub mod #ident {
259 #(#elements)*
260 }
261 });
262 }
263
264 Some(output)
265 }
266 }
267}
268
269impl EnumVariant {
270 pub fn codegen(
271 &self,
272 ns: &mut EnumNamespace,
273 resolved: &ResolvedSchema,
274 ) -> Option<TokenStream> {
275 let doc = self.description.as_ref().map(|d| {
276 quote! {
277 #[doc = #d]
278 }
279 });
280
281 let name = format_ident!("{}", self.name);
282
283 match &self.value {
284 EnumVariantValue::Repr(repr) => Some(quote! {
285 #doc
286 #name = #repr
287 }),
288 EnumVariantValue::String { rename } => {
289 let serde_attr = rename.as_ref().map(|r| {
290 quote! {
291 #[serde(rename = #r)]
292 }
293 });
294
295 Some(quote! {
296 #doc
297 #serde_attr
298 #name
299 })
300 }
301 EnumVariantValue::Tuple(values) => {
302 let mut val_tys = Vec::with_capacity(values.len());
303
304 if let [value] = values.as_slice() {
305 let enum_name = format_ident!("{}", ns.r#enum.name);
306 let ty_name = value.type_name(ns);
307
308 ns.push_top_level(quote! {
309 impl From<#ty_name> for #enum_name {
310 fn from(value: #ty_name) -> Self {
311 Self::#name(value)
312 }
313 }
314 });
315 }
316
317 for value in values {
318 let ty_name = value.type_name(ns);
319
320 if let EnumVariantTupleValue::Enum { inner, .. } = &value {
321 ns.push_element(inner.codegen(resolved)?);
322 }
323
324 val_tys.push(ty_name);
325 }
326
327 Some(quote! {
328 #name(#(#val_tys),*)
329 })
330 }
331 }
332 }
333
334 pub fn codegen_display(&self) -> Option<TokenStream> {
335 self.value.codegen_display(&self.name)
336 }
337}
338
339#[derive(Debug, Clone, PartialEq, Eq, Default)]
340pub struct Enum {
341 pub name: String,
342 pub description: Option<String>,
343 pub repr: Option<EnumRepr>,
344 pub copy: bool,
345 pub untagged: bool,
346 pub variants: Vec<EnumVariant>,
347}
348
349impl Enum {
350 pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
351 let mut result = Enum {
352 name: name.to_owned(),
353 description: schema.description.as_deref().map(ToOwned::to_owned),
354 copy: true,
355 ..Default::default()
356 };
357
358 match &schema.r#enum {
359 Some(OpenApiVariants::Int(int_variants)) => {
360 result.repr = Some(EnumRepr::U32);
361 result.variants = int_variants
362 .iter()
363 .copied()
364 .map(|i| EnumVariant {
365 name: format!("Variant{i}"),
366 value: EnumVariantValue::Repr(i as u32),
367 ..Default::default()
368 })
369 .collect();
370 }
371 Some(OpenApiVariants::Str(str_variants)) => {
372 result.variants = str_variants
373 .iter()
374 .copied()
375 .map(|s| {
376 let transformed = s.replace('&', "And").to_upper_camel_case();
377 EnumVariant {
378 value: EnumVariantValue::String {
379 rename: (transformed != s).then(|| s.to_owned()),
380 },
381 name: transformed,
382 ..Default::default()
383 }
384 })
385 .collect();
386 }
387 None => return None,
388 }
389
390 Some(result)
391 }
392
393 pub fn from_parameter_schema(name: &str, schema: &OpenApiParameterSchema) -> Option<Self> {
394 let mut result = Self {
395 name: name.to_owned(),
396 copy: true,
397 ..Default::default()
398 };
399
400 for var in schema.r#enum.as_ref()? {
401 let transformed = var.to_upper_camel_case();
402 result.variants.push(EnumVariant {
403 value: EnumVariantValue::String {
404 rename: (transformed != *var).then(|| transformed.clone()),
405 },
406 name: transformed,
407 ..Default::default()
408 });
409 }
410
411 Some(result)
412 }
413
414 pub fn from_one_of(name: &str, schemas: &[OpenApiType]) -> Option<Self> {
415 let mut result = Self {
416 name: name.to_owned(),
417 untagged: true,
418 ..Default::default()
419 };
420
421 for schema in schemas {
422 let value = EnumVariantTupleValue::from_schema(name, schema)?;
423 let name = value.name();
424
425 result.variants.push(EnumVariant {
426 name,
427 value: EnumVariantValue::Tuple(vec![value]),
428 ..Default::default()
429 });
430 }
431
432 let shared: Vec<_> = result
434 .variants
435 .iter_mut()
436 .filter(|v| v.name == "Variant")
437 .collect();
438 if shared.len() >= 2 {
439 for (idx, variant) in shared.into_iter().enumerate() {
440 let label = idx + 1;
441 variant.name = format!("Variant{}", label);
442 if let EnumVariantValue::Tuple(values) = &mut variant.value {
443 if let [EnumVariantTupleValue::Enum { name, inner, .. }] = values.as_mut_slice()
444 {
445 inner.name.push_str(&label.to_string());
446 name.push_str(&label.to_string());
447 }
448 }
449 }
450 }
451 Some(result)
452 }
453
454 pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
455 self.variants.iter().all(|v| v.value.is_display(resolved))
456 }
457
458 pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
459 self.variants
460 .iter()
461 .all(|v| v.value.is_comparable(resolved))
462 }
463
464 pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
465 let repr = self.repr.map(|r| match r {
466 EnumRepr::U8 => quote! { #[repr(u8)] },
467 EnumRepr::U32 => quote! { #[repr(u32)] },
468 });
469 let name = format_ident!("{}", self.name);
470 let desc = self.description.as_ref().map(|d| {
471 quote! {
472 #repr
473 #[doc = #d]
474 }
475 });
476
477 let mut ns = EnumNamespace {
478 r#enum: self,
479 ident: None,
480 elements: Default::default(),
481 top_level_elements: Default::default(),
482 };
483
484 let is_display = self.is_display(resolved);
485
486 let mut display = Vec::with_capacity(self.variants.len());
487 let mut variants = Vec::with_capacity(self.variants.len());
488 for variant in &self.variants {
489 variants.push(variant.codegen(&mut ns, resolved)?);
490
491 if is_display {
492 display.push(variant.codegen_display()?);
493 }
494 }
495
496 let mut derives = vec![];
497
498 if self.repr.is_some() {
499 derives.push(quote! { serde_repr::Deserialize_repr });
500 } else {
501 derives.push(quote! { serde::Deserialize });
502 }
503
504 if self.copy {
505 derives.push(quote! { Copy });
506 }
507
508 if self.is_comparable(resolved) {
509 derives.push(quote! { Eq, Hash });
510 }
511
512 let serde_attr = self.untagged.then(|| {
513 quote! {
514 #[serde(untagged)]
515 }
516 });
517
518 let display = is_display.then(|| {
519 quote! {
520 impl std::fmt::Display for #name {
521 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
522 match self {
523 #(#display),*
524 }
525 }
526 }
527 }
528 });
529
530 let module = ns.codegen();
531
532 Some(quote! {
533 #desc
534 #[derive(Debug, Clone, PartialEq, #(#derives),*)]
535 #[cfg_attr(feature = "strum", derive(strum::EnumIs, strum::EnumTryAs))]
536 #serde_attr
537 pub enum #name {
538 #(#variants),*
539 }
540 #display
541
542 #module
543 })
544 }
545}
546
#[cfg(test)]
mod test {
    use super::*;

    use crate::openapi::schema::test::get_schema;

    /// The `TornSelectionName` model resolved from the bundled schema must
    /// report that it is displayable.
    #[test]
    fn is_display() {
        let spec = get_schema();
        let resolved = ResolvedSchema::from_open_api(&spec);

        let model = resolved.models.get("TornSelectionName").unwrap();
        assert!(model.is_display(&resolved));
    }
}