1use heck::{ToSnakeCase, ToUpperCamelCase};
2use proc_macro2::TokenStream;
3use quote::{format_ident, quote};
4use syn::Ident;
5
6use crate::openapi::{
7 parameter::OpenApiParameterSchema,
8 r#type::{OpenApiType, OpenApiVariants},
9};
10
11use super::{object::PrimitiveType, ResolvedSchema};
12
/// Integer representation of a generated C-like enum.
///
/// During code generation each variant is turned into the `#[repr(..)]`
/// attribute of the same name.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EnumRepr {
    /// Generated as `#[repr(u8)]`.
    U8,
    /// Generated as `#[repr(u32)]`.
    U32,
}
18
19#[derive(Debug, Clone, PartialEq, Eq)]
20pub enum EnumVariantTupleValue {
21 Ref { ty_name: String },
22 ArrayOfRefs { ty_name: String },
23 Primitive(PrimitiveType),
24 Enum { name: String, inner: Enum },
25}
26
27impl EnumVariantTupleValue {
28 pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
29 match schema {
30 OpenApiType {
31 ref_path: Some(path),
32 ..
33 } => Some(Self::Ref {
34 ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
35 }),
36 OpenApiType {
37 r#type: Some("array"),
38 items: Some(items),
39 ..
40 } => {
41 let OpenApiType {
42 ref_path: Some(path),
43 ..
44 } = items.as_ref()
45 else {
46 return None;
47 };
48 Some(Self::ArrayOfRefs {
49 ty_name: path.strip_prefix("#/components/schemas/")?.to_owned(),
50 })
51 }
52 OpenApiType {
53 r#type: Some("string"),
54 format: None,
55 r#enum: None,
56 ..
57 } => Some(Self::Primitive(PrimitiveType::String)),
58 OpenApiType {
59 r#type: Some("string"),
60 format: None,
61 r#enum: Some(_),
62 ..
63 } => {
64 let name = format!("{name}Variant");
65 Some(Self::Enum {
66 inner: Enum::from_schema(&name, schema)?,
67 name,
68 })
69 }
70 OpenApiType {
71 r#type: Some("integer"),
72 format: Some("int64"),
73 ..
74 } => Some(Self::Primitive(PrimitiveType::I64)),
75 OpenApiType {
76 r#type: Some("integer"),
77 format: Some("int32"),
78 ..
79 } => Some(Self::Primitive(PrimitiveType::I32)),
80 _ => None,
81 }
82 }
83
84 pub fn type_name(&self, ns: &mut EnumNamespace) -> TokenStream {
85 match self {
86 Self::Ref { ty_name } => {
87 let ty = format_ident!("{ty_name}");
88 quote! { crate::models::#ty }
89 }
90 Self::ArrayOfRefs { ty_name } => {
91 let ty = format_ident!("{ty_name}");
92 quote! { Vec<crate::models::#ty> }
93 }
94 Self::Primitive(PrimitiveType::I64) => quote! { i64 },
95 Self::Primitive(PrimitiveType::I32) => quote! { i32 },
96 Self::Primitive(PrimitiveType::Float) => quote! { f32 },
97 Self::Primitive(PrimitiveType::String) => quote! { String },
98 Self::Primitive(PrimitiveType::DateTime) => quote! { chrono::DateTime<chrono::Utc> },
99 Self::Primitive(PrimitiveType::Bool) => quote! { bool },
100 Self::Enum { name, .. } => {
101 let path = ns.get_ident();
102 let ty_name = format_ident!("{name}");
103 quote! {
104 #path::#ty_name,
105 }
106 }
107 }
108 }
109
110 pub fn name(&self) -> String {
111 match self {
112 Self::Ref { ty_name } => ty_name.clone(),
113 Self::ArrayOfRefs { ty_name } => format!("{ty_name}s"),
114 Self::Primitive(PrimitiveType::I64) => "I64".to_owned(),
115 Self::Primitive(PrimitiveType::I32) => "I32".to_owned(),
116 Self::Primitive(PrimitiveType::Float) => "Float".to_owned(),
117 Self::Primitive(PrimitiveType::String) => "String".to_owned(),
118 Self::Primitive(PrimitiveType::DateTime) => "DateTime".to_owned(),
119 Self::Primitive(PrimitiveType::Bool) => "Bool".to_owned(),
120 Self::Enum { .. } => "Variant".to_owned(),
121 }
122 }
123
124 pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
125 match self {
126 Self::Primitive(_) => true,
127 Self::Ref { ty_name } | Self::ArrayOfRefs { ty_name } => resolved
128 .models
129 .get(ty_name)
130 .map(|f| f.is_display(resolved))
131 .unwrap_or_default(),
132 Self::Enum { inner, .. } => inner.is_display(resolved),
133 }
134 }
135
136 pub fn codegen_display(&self) -> TokenStream {
137 match self {
138 Self::ArrayOfRefs { .. } => quote! {
139 write!(f, "{}", value.iter().map(ToString::to_string).collect::<Vec<_>>().join(","))
140 },
141 _ => quote! {
142 write!(f, "{}", value)
143 },
144 }
145 }
146}
147
148#[derive(Debug, Clone, PartialEq, Eq)]
149pub enum EnumVariantValue {
150 Repr(u32),
151 String { rename: Option<String> },
152 Tuple(Vec<EnumVariantTupleValue>),
153}
154
155impl Default for EnumVariantValue {
156 fn default() -> Self {
157 Self::String { rename: None }
158 }
159}
160
161impl EnumVariantValue {
162 pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
163 match self {
164 Self::Repr(_) | Self::String { .. } => true,
165 Self::Tuple(val) => {
166 val.len() == 1
167 && val
168 .iter()
169 .next()
170 .map(|v| v.is_display(resolved))
171 .unwrap_or_default()
172 }
173 }
174 }
175
176 pub fn codegen_display(&self, name: &str) -> Option<TokenStream> {
177 let variant = format_ident!("{name}");
178
179 match self {
180 Self::Repr(i) => Some(quote! { Self::#variant => write!(f, "{}", #i) }),
181 Self::String { rename } => {
182 let name = rename.as_deref().unwrap_or(name);
183 Some(quote! { Self::#variant => write!(f, #name) })
184 }
185 Self::Tuple(values) if values.len() == 1 => {
186 let rhs = values.first().unwrap().codegen_display();
187 Some(quote! { Self::#variant(value) => #rhs })
188 }
189 _ => None,
190 }
191 }
192}
193
194#[derive(Debug, Clone, PartialEq, Eq, Default)]
195pub struct EnumVariant {
196 pub name: String,
197 pub description: Option<String>,
198 pub value: EnumVariantValue,
199}
200
201pub struct EnumNamespace<'e> {
202 r#enum: &'e Enum,
203 ident: Option<Ident>,
204 elements: Vec<TokenStream>,
205 top_level_elements: Vec<TokenStream>,
206}
207
208impl EnumNamespace<'_> {
209 pub fn get_ident(&mut self) -> Ident {
210 self.ident
211 .get_or_insert_with(|| {
212 let name = self.r#enum.name.to_snake_case();
213 format_ident!("{name}")
214 })
215 .clone()
216 }
217
218 pub fn push_element(&mut self, el: TokenStream) {
219 self.elements.push(el);
220 }
221
222 pub fn push_top_level(&mut self, el: TokenStream) {
223 self.top_level_elements.push(el);
224 }
225
226 pub fn codegen(mut self) -> Option<TokenStream> {
227 if self.elements.is_empty() && self.top_level_elements.is_empty() {
228 None
229 } else {
230 let top_level = &self.top_level_elements;
231 let mut output = quote! {
232 #(#top_level)*
233 };
234
235 if !self.elements.is_empty() {
236 let ident = self.get_ident();
237 let elements = self.elements;
238 output.extend(quote! {
239 pub mod #ident {
240 #(#elements)*
241 }
242 });
243 }
244
245 Some(output)
246 }
247 }
248}
249
250impl EnumVariant {
251 pub fn codegen(
252 &self,
253 ns: &mut EnumNamespace,
254 resolved: &ResolvedSchema,
255 ) -> Option<TokenStream> {
256 let doc = self.description.as_ref().map(|d| {
257 quote! {
258 #[doc = #d]
259 }
260 });
261
262 let name = format_ident!("{}", self.name);
263
264 match &self.value {
265 EnumVariantValue::Repr(repr) => Some(quote! {
266 #doc
267 #name = #repr
268 }),
269 EnumVariantValue::String { rename } => {
270 let serde_attr = rename.as_ref().map(|r| {
271 quote! {
272 #[serde(rename = #r)]
273 }
274 });
275
276 Some(quote! {
277 #doc
278 #serde_attr
279 #name
280 })
281 }
282 EnumVariantValue::Tuple(values) => {
283 let mut val_tys = Vec::with_capacity(values.len());
284
285 if let [value] = values.as_slice() {
286 let enum_name = format_ident!("{}", ns.r#enum.name);
287 let ty_name = value.type_name(ns);
288
289 ns.push_top_level(quote! {
290 impl From<#ty_name> for #enum_name {
291 fn from(value: #ty_name) -> Self {
292 Self::#name(value)
293 }
294 }
295 });
296 }
297
298 for value in values {
299 let ty_name = value.type_name(ns);
300
301 if let EnumVariantTupleValue::Enum { inner, .. } = &value {
302 ns.push_element(inner.codegen(resolved)?);
303 }
304
305 val_tys.push(ty_name);
306 }
307
308 Some(quote! {
309 #name(#(#val_tys),*)
310 })
311 }
312 }
313 }
314
315 pub fn codegen_display(&self) -> Option<TokenStream> {
316 self.value.codegen_display(&self.name)
317 }
318}
319
320#[derive(Debug, Clone, PartialEq, Eq, Default)]
321pub struct Enum {
322 pub name: String,
323 pub description: Option<String>,
324 pub repr: Option<EnumRepr>,
325 pub copy: bool,
326 pub untagged: bool,
327 pub variants: Vec<EnumVariant>,
328}
329
330impl Enum {
331 pub fn from_schema(name: &str, schema: &OpenApiType) -> Option<Self> {
332 let mut result = Enum {
333 name: name.to_owned(),
334 description: schema.description.as_deref().map(ToOwned::to_owned),
335 copy: true,
336 ..Default::default()
337 };
338
339 match &schema.r#enum {
340 Some(OpenApiVariants::Int(int_variants)) => {
341 result.repr = Some(EnumRepr::U32);
342 result.variants = int_variants
343 .iter()
344 .copied()
345 .map(|i| EnumVariant {
346 name: format!("Variant{i}"),
347 value: EnumVariantValue::Repr(i as u32),
348 ..Default::default()
349 })
350 .collect();
351 }
352 Some(OpenApiVariants::Str(str_variants)) => {
353 result.variants = str_variants
354 .iter()
355 .copied()
356 .map(|s| {
357 let transformed = s.replace('&', "And").to_upper_camel_case();
358 EnumVariant {
359 value: EnumVariantValue::String {
360 rename: (transformed != s).then(|| s.to_owned()),
361 },
362 name: transformed,
363 ..Default::default()
364 }
365 })
366 .collect();
367 }
368 None => return None,
369 }
370
371 Some(result)
372 }
373
374 pub fn from_parameter_schema(name: &str, schema: &OpenApiParameterSchema) -> Option<Self> {
375 let mut result = Self {
376 name: name.to_owned(),
377 copy: true,
378 ..Default::default()
379 };
380
381 for var in schema.r#enum.as_ref()? {
382 let transformed = var.to_upper_camel_case();
383 result.variants.push(EnumVariant {
384 value: EnumVariantValue::String {
385 rename: (transformed != *var).then(|| transformed.clone()),
386 },
387 name: transformed,
388 ..Default::default()
389 });
390 }
391
392 Some(result)
393 }
394
395 pub fn from_one_of(name: &str, schemas: &[OpenApiType]) -> Option<Self> {
396 let mut result = Self {
397 name: name.to_owned(),
398 untagged: true,
399 ..Default::default()
400 };
401
402 for schema in schemas {
403 let value = EnumVariantTupleValue::from_schema(name, schema)?;
404 let name = value.name();
405
406 result.variants.push(EnumVariant {
407 name,
408 value: EnumVariantValue::Tuple(vec![value]),
409 ..Default::default()
410 });
411 }
412
413 Some(result)
414 }
415
416 pub fn is_display(&self, resolved: &ResolvedSchema) -> bool {
417 self.variants.iter().all(|v| v.value.is_display(resolved))
418 }
419
420 pub fn codegen(&self, resolved: &ResolvedSchema) -> Option<TokenStream> {
421 let repr = self.repr.map(|r| match r {
422 EnumRepr::U8 => quote! { #[repr(u8)]},
423 EnumRepr::U32 => quote! { #[repr(u32)]},
424 });
425 let name = format_ident!("{}", self.name);
426 let desc = self.description.as_ref().map(|d| {
427 quote! {
428 #repr
429 #[doc = #d]
430 }
431 });
432
433 let mut ns = EnumNamespace {
434 r#enum: self,
435 ident: None,
436 elements: Default::default(),
437 top_level_elements: Default::default(),
438 };
439
440 let is_display = self.is_display(resolved);
441
442 let mut display = Vec::with_capacity(self.variants.len());
443 let mut variants = Vec::with_capacity(self.variants.len());
444 for variant in &self.variants {
445 variants.push(variant.codegen(&mut ns, resolved)?);
446
447 if is_display {
448 display.push(variant.codegen_display()?);
449 }
450 }
451
452 let mut derives = vec![];
453
454 if self.repr.is_some() {
455 derives.push(quote! { serde_repr::Deserialize_repr });
456 } else {
457 derives.push(quote! { serde::Deserialize });
458 }
459
460 if self.copy {
461 derives.push(quote! { Copy, Hash });
462 }
463
464 let serde_attr = self.untagged.then(|| {
465 quote! {
466 #[serde(untagged)]
467 }
468 });
469
470 let display = is_display.then(|| {
471 quote! {
472 impl std::fmt::Display for #name {
473 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
474 match self {
475 #(#display),*
476 }
477 }
478 }
479 }
480 });
481
482 let module = ns.codegen();
483
484 Some(quote! {
485 #desc
486 #[derive(Debug, Clone, PartialEq, #(#derives),*)]
487 #[cfg_attr(feature = "strum", derive(strum::EnumIs, strum::EnumTryAs))]
488 #serde_attr
489 pub enum #name {
490 #(#variants),*
491 }
492 #display
493
494 #module
495 })
496 }
497}
498
#[cfg(test)]
mod test {
    use super::*;

    use crate::openapi::schema::test::get_schema;

    /// The `TornSelectionName` model of the test schema must be displayable.
    #[test]
    fn is_display() {
        let api = get_schema();
        let resolved = ResolvedSchema::from_open_api(&api);

        let model = resolved.models.get("TornSelectionName").unwrap();
        let displayable = model.is_display(&resolved);
        assert!(displayable);
    }
}