1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use syn::Ident;
5
6use crate::openapi::{
7 parameter::OpenApiParameterSchema,
8 r#type::{OpenApiType, OpenApiVariants},
9};
10
11use super::{object::PrimitiveType, Model, ResolvedSchema};
12
/// Integer representation requested for a generated enum.
///
/// `U8` maps to `#[repr(u8)]` and `U32` to `#[repr(u32)]` in the generated code.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EnumRepr {
    /// `#[repr(u8)]`
    U8,
    /// `#[repr(u32)]`
    U32,
}
18
19#[derive(Debug, Clone, PartialEq, Eq)]
20pub enum EnumVariantTupleValue {
21 Ref { ty_name: String },
22 ArrayOfRefs { ty_name: String },
23 Primitive(PrimitiveType),
24 Enum { name: String, inner: Enum },
25}
26
27impl EnumVariantTupleValue {
28 pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
29 match schema {
30 OpenApiType {
31 ref_path: Some(path),
32 ..
33 } => Some(Self::Ref {
34 ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
35 }),
36 OpenApiType {
37 r#type: Some("array"),
38 items: Some(items),
39 ..
40 } => {
41 let OpenApiType {
42 ref_path: Some(path),
43 ..
44 } = items.as_ref()
45 else {
46 return None;
47 };
48 Some(Self::ArrayOfRefs {
49 ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
50 })
51 }
52 OpenApiType {
53 r#type: Some("string"),
54 format: None,
55 r#enum: None,
56 ..
57 } => Some(Self::Primitive(PrimitiveType::String)),
58 OpenApiType {
59 r#type: Some("string"),
60 format: None,
61 r#enum: Some(_),
62 ..
63 } => {
64 let name = format!("{name}Variant");
65 Some(Self::Enum {
66 inner: Enum::from_schema(&name, schema)?,
67 name,
68 })
69 }
70 OpenApiType {
71 r#type: Some("integer"),
72 format: Some("int64"),
73 ..
74 } => Some(Self::Primitive(PrimitiveType::I64)),
75 OpenApiType {
76 r#type: Some("integer"),
77 format: Some("int32"),
78 ..
79 } => Some(Self::Primitive(PrimitiveType::I32)),
80 OpenApiType {
81 r#type: Some("number"),
82 format: Some("float") | None,
83 ..
84 } => Some(Self::Primitive(PrimitiveType::Float)),
85 _ => None,
86 }
87 }
88
89 pub fn type_name(&self, ns: &mut EnumNamespace) -> TokenStream {
90 match self {
91 Self::Ref { ty_name } => {
92 let ty = format_ident!("{ty_name}");
93 quote! { crate::models::#ty }
94 }
95 Self::ArrayOfRefs { ty_name } => {
96 let ty = format_ident!("{ty_name}");
97 quote! { Vec<crate::models::#ty> }
98 }
99 Self::Primitive(PrimitiveType::I64) => quote! { i64 },
100 Self::Primitive(PrimitiveType::I32) => quote! { i32 },
101 Self::Primitive(PrimitiveType::Float) => quote! { f32 },
102 Self::Primitive(PrimitiveType::String) => quote! { String },
103 Self::Primitive(PrimitiveType::DateTime) => quote! { chrono::DateTime<chrono::Utc> },
104 Self::Primitive(PrimitiveType::Bool) => quote! { bool },
105 Self::Enum { name, .. } => {
106 let path = ns.get_ident();
107 let ty_name = format_ident!("{name}");
108 quote! {
109 #path::#ty_name,
110 }
111 }
112 }
113 }
114
115 pub fn name(&self) -> String {
116 match self {
117 Self::Ref { ty_name } => ty_name.clone(),
118 Self::ArrayOfRefs { ty_name } => format!("{ty_name}s"),
119 Self::Primitive(PrimitiveType::I64) => "I64".to_owned(),
120 Self::Primitive(PrimitiveType::I32) => "I32".to_owned(),
121 Self::Primitive(PrimitiveType::Float) => "Float".to_owned(),
122 Self::Primitive(PrimitiveType::String) => "String".to_owned(),
123 Self::Primitive(PrimitiveType::DateTime) => "DateTime".to_owned(),
124 Self::Primitive(PrimitiveType::Bool) => "Bool".to_owned(),
125 Self::Enum { .. } => "Variant".to_owned(),
126 }
127 }
128
129 pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
130 match self {
131 Self::Primitive(_) => true,
132 Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
133 .models
134 .get(ty_name)
135 .map(|f| f.is_display(resolved))
136 .unwrap_or_default(),
137 Self::Enum { inner, .. } => inner.is_display(resolved),
138 }
139 }
140
141 pub fn codegen_display(&self) -> TokenStream {
142 match self {
143 Self::ArrayOfRefs { .. } => quote! {
144 write!(f, "{}", value.iter().map(ToString::to_string).collect::<Vec<_>>().join(","))
145 },
146 _ => quote! {
147 write!(f, "{}", value)
148 },
149 }
150 }
151
152 pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
153 match self {
154 Self::Primitive(PrimitiveType::Float) => false,
155 Self::Primitive(_) => true,
156 Self::Enum { inner, .. } => inner.is_comparable(resolved),
157 Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
158 .models
159 .get(ty_name)
160 .map(|m| matches!(m, Model::Newtype(_)))
161 .unwrap_or_default(),
162 }
163 }
164}
165
166#[derive(Debug, Clone, PartialEq, Eq)]
167pub enum EnumVariantValue {
168 Repr(u32),
169 String { rename: Option<String> },
170 Tuple(Vec<EnumVariantTupleValue>),
171}
172
173impl Default for EnumVariantValue {
174 fn default() -> Self {
175 Self::String { rename: None }
176 }
177}
178
179impl EnumVariantValue {
180 pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
181 match self {
182 Self::Repr(_) | Self::String { .. } => true,
183 Self::Tuple(val) => {
184 val.len() == 1
185 && val
186 .iter()
187 .next()
188 .map(|v| v.is_display(resolved))
189 .unwrap_or_default()
190 }
191 }
192 }
193
194 pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
195 match self {
196 Self::Repr(_) | Self::String { .. } => true,
197 Self::Tuple(values) => values.iter().all(|v| v.is_comparable(resolved)),
198 }
199 }
200
201 pub fn codegen_display(&self, name: &str) -> Option<TokenStream> {
202 let variant = format_ident!("{name}");
203
204 match self {
205 Self::Repr(i) => Some(quote! { Self::#variant => write!(f, "{}", #i) }),
206 Self::String { rename } => {
207 let name = rename.as_deref().unwrap_or(name);
208 Some(quote! { Self::#variant => write!(f, #name) })
209 }
210 Self::Tuple(values) if values.len() == 1 => {
211 let rhs = values.first().unwrap().codegen_display();
212 Some(quote! { Self::#variant(value) => #rhs })
213 }
214 _ => None,
215 }
216 }
217}
218
219#[derive(Debug, Clone, PartialEq, Eq, Default)]
220pub struct EnumVariant {
221 pub name: String,
222 pub description: Option<String>,
223 pub value: EnumVariantValue,
224}
225
226pub struct EnumNamespace<'e> {
227 r#enum: &'e Enum,
228 ident: Option<Ident>,
229 elements: Vec<TokenStream>,
230 top_level_elements: Vec<TokenStream>,
231}
232
233impl EnumNamespace<'_> {
234 pub fn get_ident(&mut self) -> Ident {
235 self.ident
236 .get_or_insert_with(|| {
237 let name = self.r#enum.name.to_snake_case();
238 format_ident!("{name}")
239 })
240 .clone()
241 }
242
243 pub fn push_element(&mut self, el: TokenStream) {
244 self.elements.push(el);
245 }
246
247 pub fn push_top_level(&mut self, el: TokenStream) {
248 self.top_level_elements.push(el);
249 }
250
251 pub fn codegen(mut self) -> Option<TokenStream> {
252 if self.elements.is_empty() && self.top_level_elements.is_empty() {
253 None
254 } else {
255 let top_level = &self.top_level_elements;
256 let mut output = quote! {
257 #(#top_level)*
258 };
259
260 if !self.elements.is_empty() {
261 let ident = self.get_ident();
262 let elements = self.elements;
263 output.extend(quote! {
264 pub mod #ident {
265 #(#elements)*
266 }
267 });
268 }
269
270 Some(output)
271 }
272 }
273}
274
275impl EnumVariant {
276 pub fn codegen(
277 &self,
278 ns: &mut EnumNamespace,
279 resolved: &ResolvedSchema,
280 ) -> Option<TokenStream> {
281 let doc = self.description.as_ref().map(|d| {
282 quote! {
283 #[doc = #d]
284 }
285 });
286
287 let name = format_ident!("{}", self.name);
288
289 match &self.value {
290 EnumVariantValue::Repr(repr) => Some(quote! {
291 #doc
292 #name = #repr
293 }),
294 EnumVariantValue::String { rename } => {
295 let serde_attr = rename.as_ref().map(|r| {
296 quote! {
297 #[serde(rename = #r)]
298 }
299 });
300
301 Some(quote! {
302 #doc
303 #serde_attr
304 #name
305 })
306 }
307 EnumVariantValue::Tuple(values) => {
308 let mut val_tys = Vec::with_capacity(values.len());
309
310 if let [value] = values.as_slice() {
311 let enum_name = format_ident!("{}", ns.r#enum.name);
312 let ty_name = value.type_name(ns);
313
314 ns.push_top_level(quote! {
315 impl From<#ty_name> for #enum_name {
316 fn from(value: #ty_name) -> Self {
317 Self::#name(value)
318 }
319 }
320 });
321 }
322
323 for value in values {
324 let ty_name = value.type_name(ns);
325
326 if let EnumVariantTupleValue::Enum { inner, .. } = &value {
327 ns.push_element(inner.codegen(resolved)?);
328 }
329
330 val_tys.push(ty_name);
331 }
332
333 Some(quote! {
334 #name(#(#val_tys),*)
335 })
336 }
337 }
338 }
339
340 pub fn codegen_display(&self) -> Option<TokenStream> {
341 self.value.codegen_display(&self.name)
342 }
343}
344
345#[derive(Debug, Clone, PartialEq, Eq, Default)]
346pub struct Enum {
347 pub name: String,
348 pub description: Option<String>,
349 pub repr: Option<EnumRepr>,
350 pub copy: bool,
351 pub untagged: bool,
352 pub variants: Vec<EnumVariant>,
353}
354
355impl Enum {
356 pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
357 let mut result = Enum {
358 name: name.to_owned(),
359 description: schema.description.as_deref().map(ToOwned::to_owned),
360 copy: true,
361 ..Default::default()
362 };
363
364 match &schema.r#enum {
365 Some(OpenApiVariants::Int(int_variants)) => {
366 result.repr = Some(EnumRepr::U32);
367 result.variants = int_variants
368 .iter()
369 .copied()
370 .map(|i| EnumVariant {
371 name: format!("Variant{i}"),
372 value: EnumVariantValue::Repr(i as u32),
373 ..Default::default()
374 })
375 .collect();
376 }
377 Some(OpenApiVariants::Str(str_variants)) => {
378 result.variants = str_variants
379 .iter()
380 .copied()
381 .map(|s| {
382 let transformed = s.replace('&', "And").to_upper_camel_case();
383 EnumVariant {
384 value: EnumVariantValue::String {
385 rename: (transformed != s).then(|| s.to_owned()),
386 },
387 name: transformed,
388 ..Default::default()
389 }
390 })
391 .collect();
392 }
393 None => return None,
394 }
395
396 Some(result)
397 }
398
399 pub fn from_parameter_schema(name: &str, schema: &OpenApiParameterSchema) -> Option<Self> {
400 let mut result = Self {
401 name: name.to_owned(),
402 copy: true,
403 ..Default::default()
404 };
405
406 for var in schema.r#enum.as_ref()? {
407 let transformed = var.to_upper_camel_case();
408 result.variants.push(EnumVariant {
409 value: EnumVariantValue::String {
410 rename: (transformed != *var).then(|| transformed.clone()),
411 },
412 name: transformed,
413 ..Default::default()
414 });
415 }
416
417 Some(result)
418 }
419
420 pub fn from_one_of(name: &str, schemas: &[OpenApiType]) -> Option<Self> {
421 let mut result = Self {
422 name: name.to_owned(),
423 untagged: true,
424 ..Default::default()
425 };
426
427 for schema in schemas {
428 let value = EnumVariantTupleValue::from_schema(name, schema)?;
429 let name = value.name();
430
431 result.variants.push(EnumVariant {
432 name,
433 value: EnumVariantValue::Tuple(vec![value]),
434 ..Default::default()
435 });
436 }
437
438 let shared: Vec<_> = result
440 .variants
441 .iter_mut()
442 .filter(|v| v.name == "Variant")
443 .collect();
444 if shared.len() >= 2 {
445 for (idx, variant) in shared.into_iter().enumerate() {
446 let label = idx + 1;
447 variant.name = format!("Variant{}", label);
448 if let EnumVariantValue::Tuple(values) = &mut variant.value {
449 if let [EnumVariantTupleValue::Enum { name, inner, .. }] = values.as_mut_slice()
450 {
451 inner.name.push_str(&label.to_string());
452 name.push_str(&label.to_string());
453 }
454 }
455 }
456 }
457 Some(result)
458 }
459
460 pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
461 self.variants.iter().all(|v| v.value.is_display(resolved))
462 }
463
464 pub fn is_comparable(&self, resolved: &ResolvedSchema) -> bool {
465 self.variants
466 .iter()
467 .all(|v| v.value.is_comparable(resolved))
468 }
469
470 pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
471 let repr = self.repr.map(|r| match r {
472 EnumRepr::U8 => quote! { #[repr(u8)] },
473 EnumRepr::U32 => quote! { #[repr(u32)] },
474 });
475 let name = format_ident!("{}", self.name);
476 let desc = self.description.as_ref().map(|d| {
477 quote! {
478 #repr
479 #[doc = #d]
480 }
481 });
482
483 let mut ns = EnumNamespace {
484 r#enum: self,
485 ident: None,
486 elements: Default::default(),
487 top_level_elements: Default::default(),
488 };
489
490 let is_display = self.is_display(resolved);
491
492 let mut display = Vec::with_capacity(self.variants.len());
493 let mut variants = Vec::with_capacity(self.variants.len());
494 for variant in &self.variants {
495 variants.push(variant.codegen(&mut ns, resolved)?);
496
497 if is_display {
498 display.push(variant.codegen_display()?);
499 }
500 }
501
502 let mut derives = vec![];
503
504 if self.repr.is_some() {
505 derives.push(quote! { serde_repr::Deserialize_repr });
506 } else {
507 derives.push(quote! { serde::Deserialize });
508 }
509
510 if self.copy {
511 derives.push(quote! { Copy });
512 }
513
514 if self.is_comparable(resolved) {
515 derives.push(quote! { Eq, Hash });
516 }
517
518 let serde_attr = self.untagged.then(|| {
519 quote! {
520 #[serde(untagged)]
521 }
522 });
523
524 let display = is_display.then(|| {
525 quote! {
526 impl std::fmt::Display for #name {
527 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
528 match self {
529 #(#display),*
530 }
531 }
532 }
533 }
534 });
535
536 let module = ns.codegen();
537
538 Some(quote! {
539 #desc
540 #[derive(Debug, Clone, PartialEq, #(#derives),*)]
541 #[cfg_attr(feature = "strum", derive(strum::EnumIs, strum::EnumTryAs))]
542 #serde_attr
543 pub enum #name {
544 #(#variants),*
545 }
546 #display
547
548 #module
549 })
550 }
551}
552
// Tests for the enum generator, run against the schema returned by
// `crate::openapi::schema::test::get_schema`.
#[cfg(test)]
mod test {
    use super::*;

    use crate::openapi::schema::test::get_schema;

    // `TornSelectionName` must resolve and report itself as displayable.
    #[test]
    fn is_display() {
        let spec = get_schema();
        let resolved = ResolvedSchema::from_open_api(&spec);

        let model = resolved.models.get("TornSelectionName").unwrap();
        let displayable = model.is_display(&resolved);
        assert!(displayable);
    }
}