Skip to main content

unistructgen_openapi_parser/
types.rs

1//! Type conversion utilities for OpenAPI schemas
2
3use crate::error::{OpenApiError, Result};
4use openapiv3::{Schema, SchemaKind, Type};
5use unistructgen_core::{IRTypeRef, PrimitiveKind};
6
7/// Convert OpenAPI type to IR type reference
8pub fn openapi_type_to_ir(
9    schema: &Schema,
10    type_name_hint: Option<&str>,
11) -> Result<IRTypeRef> {
12    match &schema.schema_kind {
13        SchemaKind::Type(Type::String(string_type)) => {
14            // Check for format-specific types
15            let format_str = format!("{:?}", string_type.format);
16            if format_str.contains("DateTime") {
17                Ok(IRTypeRef::Primitive(PrimitiveKind::DateTime))
18            } else if format_str.contains("Uuid") {
19                Ok(IRTypeRef::Primitive(PrimitiveKind::Uuid))
20            } else if !string_type.enumeration.is_empty() {
21                // Enum type - will be handled by schema converter
22                if let Some(name) = type_name_hint {
23                    Ok(IRTypeRef::Named(to_pascal_case(name)))
24                } else {
25                    Ok(IRTypeRef::Primitive(PrimitiveKind::String))
26                }
27            } else {
28                Ok(IRTypeRef::Primitive(PrimitiveKind::String))
29            }
30        }
31
32        SchemaKind::Type(Type::Number(number_type)) => {
33            let format_str = format!("{:?}", number_type.format);
34            if format_str.contains("Float") {
35                Ok(IRTypeRef::Primitive(PrimitiveKind::F32))
36            } else {
37                Ok(IRTypeRef::Primitive(PrimitiveKind::F64))
38            }
39        }
40
41        SchemaKind::Type(Type::Integer(int_type)) => {
42            let format_str = format!("{:?}", int_type.format);
43            if format_str.contains("Int32") {
44                Ok(IRTypeRef::Primitive(PrimitiveKind::I32))
45            } else {
46                Ok(IRTypeRef::Primitive(PrimitiveKind::I64))
47            }
48        }
49
50        SchemaKind::Type(Type::Boolean(_)) => {
51            Ok(IRTypeRef::Primitive(PrimitiveKind::Bool))
52        }
53
54        SchemaKind::Type(Type::Array(array_type)) => {
55            if let Some(ref items) = array_type.items {
56                let item_type = match items {
57                    openapiv3::ReferenceOr::Reference { reference } => {
58                        // Extract type name from reference
59                        let type_name = extract_type_name_from_ref(&reference);
60                        IRTypeRef::Named(type_name)
61                    }
62                    openapiv3::ReferenceOr::Item(schema) => {
63                        openapi_type_to_ir(schema.as_ref(), type_name_hint)?
64                    }
65                };
66                Ok(IRTypeRef::Vec(Box::new(item_type)))
67            } else {
68                // Array without items - use generic JSON value
69                Ok(IRTypeRef::Vec(Box::new(IRTypeRef::Primitive(
70                    PrimitiveKind::Json,
71                ))))
72            }
73        }
74
75        SchemaKind::Type(Type::Object(obj_type)) => {
76            // Check if it's a map/dictionary
77            if obj_type.properties.is_empty() && obj_type.additional_properties.is_some() {
78                // This is a map
79                let value_type = match obj_type.additional_properties.as_ref().unwrap() {
80                    openapiv3::AdditionalProperties::Any(true) => {
81                        IRTypeRef::Primitive(PrimitiveKind::Json)
82                    }
83                    openapiv3::AdditionalProperties::Schema(schema_ref) => {
84                        match schema_ref.as_ref() {
85                            openapiv3::ReferenceOr::Reference { reference } => {
86                                IRTypeRef::Named(extract_type_name_from_ref(reference))
87                            }
88                            openapiv3::ReferenceOr::Item(schema) => {
89                                openapi_type_to_ir(schema, None)?
90                            }
91                        }
92                    }
93                    _ => IRTypeRef::Primitive(PrimitiveKind::Json),
94                };
95
96                Ok(IRTypeRef::Map(
97                    Box::new(IRTypeRef::Primitive(PrimitiveKind::String)),
98                    Box::new(value_type),
99                ))
100            } else if let Some(name) = type_name_hint {
101                // Named object type
102                Ok(IRTypeRef::Named(to_pascal_case(name)))
103            } else {
104                // Anonymous object - use JSON value
105                Ok(IRTypeRef::Primitive(PrimitiveKind::Json))
106            }
107        }
108
109        SchemaKind::OneOf { .. } | SchemaKind::AnyOf { .. } | SchemaKind::AllOf { .. } => {
110            // Schema composition - will be handled specially
111            if let Some(name) = type_name_hint {
112                Ok(IRTypeRef::Named(to_pascal_case(name)))
113            } else {
114                Ok(IRTypeRef::Primitive(PrimitiveKind::Json))
115            }
116        }
117
118        SchemaKind::Any(_) => {
119            // Untyped schema - use JSON value
120            Ok(IRTypeRef::Primitive(PrimitiveKind::Json))
121        }
122
123        SchemaKind::Not { .. } => {
124            Err(OpenApiError::unsupported_type("not schemas are not supported"))
125        }
126    }
127}
128
/// Extract the type name from an OpenAPI `$ref` string.
///
/// Returns the final `/`-separated segment of the reference. JSON Pointer
/// escapes in that segment are decoded (RFC 6901: `~1` -> `/`, `~0` -> `~`) so
/// the result matches the component name as written in the spec. A reference
/// without any `/` is returned whole, apart from that unescaping.
///
/// Examples:
/// * `"#/components/schemas/User"` -> `"User"`
/// * `"#/components/schemas/Foo~1Bar"` -> `"Foo/Bar"`
pub fn extract_type_name_from_ref(reference: &str) -> String {
    let last_segment = reference
        .rsplit_once('/')
        .map_or(reference, |(_, last)| last);
    // RFC 6901 requires decoding `~1` before `~0`, so `~01` becomes `~1`, not `/`.
    last_segment.replace("~1", "/").replace("~0", "~")
}
138
/// Convert a snake_case, kebab-case, or space-separated name to PascalCase.
///
/// `_`, `-` and ` ` act as word separators and are dropped. The first character
/// of each word is uppercased and the rest of the word is kept as-is, so both
/// `"user_profile"` and `"userProfile"` become `"UserProfile"`. Consecutive
/// separators never produce empty words.
pub fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    let mut at_word_start = true;

    for ch in s.chars() {
        match ch {
            '_' | '-' | ' ' => at_word_start = true,
            _ if at_word_start => {
                // Uppercasing can expand to several chars (e.g. 'ß' -> "SS").
                pascal.extend(ch.to_uppercase());
                at_word_start = false;
            }
            _ => pascal.push(ch),
        }
    }

    pascal
}
152
/// Convert PascalCase or camelCase to snake_case.
///
/// Each uppercase character is lowercased. An underscore is inserted before it
/// when it starts a new word:
/// * after a lowercase character (`"userName"` -> `"user_name"`), or
/// * when it is the last capital of an acronym run followed by a lowercase
///   character (`"APIKey"` -> `"api_key"`, `"Base64Encoded"` -> `"base64_encoded"`).
///
/// No underscore is inserted when the previous character is not alphanumeric
/// (e.g. `_`, `-`, space), so `"User_Name"` becomes `"user_name"`, not
/// `"user__name"`. Non-uppercase characters are copied unchanged.
pub fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + s.len() / 2);

    for (i, &c) in chars.iter().enumerate() {
        if !c.is_uppercase() {
            result.push(c);
            continue;
        }

        if i > 0 {
            let prev = chars[i - 1];
            let next_is_lower = chars.get(i + 1).map_or(false, |n| n.is_lowercase());
            // Only split after a word character; a preceding separator already
            // marks the boundary.
            if prev.is_alphanumeric() && (prev.is_lowercase() || next_is_lower) {
                result.push('_');
            }
        }

        // `to_lowercase` may yield several chars (e.g. 'İ' -> "i\u{307}"); keep them all.
        result.extend(c.to_lowercase());
    }

    result
}
180
181/// Sanitize field name to be a valid Rust identifier
182pub fn sanitize_field_name(name: &str) -> String {
183    // Convert to snake_case first
184    let snake = to_snake_case(name);
185
186    // Replace invalid characters
187    let sanitized: String = snake
188        .chars()
189        .map(|c| {
190            if c.is_alphanumeric() || c == '_' {
191                c
192            } else {
193                '_'
194            }
195        })
196        .collect();
197
198    // Check if it starts with a number
199    if sanitized.chars().next().map_or(false, |c| c.is_numeric()) {
200        format!("_{}", sanitized)
201    } else if is_rust_keyword(&sanitized) {
202        // Append underscore to Rust keywords
203        format!("{}_", sanitized)
204    } else {
205        sanitized
206    }
207}
208
/// Check whether `s` is a Rust keyword that cannot be used as a plain identifier.
///
/// Covers strict keywords and reserved keywords, including those added in the
/// 2018 (`async`, `await`, `dyn`, `try`) and 2024 (`gen`) editions. That way
/// generated code compiles regardless of edition. Weak keywords such as `union`
/// are valid identifiers and are not included.
pub fn is_rust_keyword(s: &str) -> bool {
    matches!(
        s,
        // Strict keywords (all editions).
        "as" | "break"
            | "const"
            | "continue"
            | "crate"
            | "else"
            | "enum"
            | "extern"
            | "false"
            | "fn"
            | "for"
            | "if"
            | "impl"
            | "in"
            | "let"
            | "loop"
            | "match"
            | "mod"
            | "move"
            | "mut"
            | "pub"
            | "ref"
            | "return"
            | "self"
            | "Self"
            | "static"
            | "struct"
            | "super"
            | "trait"
            | "true"
            | "type"
            | "unsafe"
            | "use"
            | "where"
            | "while"
            // Strict keywords since the 2018 edition.
            | "async"
            | "await"
            | "dyn"
            // Reserved for future use (all editions).
            | "abstract"
            | "become"
            | "box"
            | "do"
            | "final"
            | "macro"
            | "override"
            | "priv"
            | "typeof"
            | "unsized"
            | "virtual"
            | "yield"
            // Reserved since the 2018 edition.
            | "try"
            // Reserved since the 2024 edition.
            | "gen"
    )
}
264
#[cfg(test)]
mod tests {
    use super::*;

    /// Naming helpers are checked table-style: each row is (input, expected output).
    #[test]
    fn test_to_pascal_case() {
        let cases = [
            ("user_profile", "UserProfile"),
            ("api-key", "ApiKey"),
            ("simple", "Simple"),
            ("my_long_type_name", "MyLongTypeName"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_pascal_case(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_to_snake_case() {
        let cases = [
            ("UserProfile", "user_profile"),
            ("APIKey", "api_key"),
            ("simple", "simple"),
            ("myLongTypeName", "my_long_type_name"),
        ];
        for (input, expected) in cases {
            assert_eq!(to_snake_case(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_sanitize_field_name() {
        let cases = [
            ("type", "type_"),
            ("123field", "_123field"),
            ("user-name", "user_name"),
            ("valid_field", "valid_field"),
        ];
        for (input, expected) in cases {
            assert_eq!(sanitize_field_name(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_extract_type_name_from_ref() {
        let cases = [
            ("#/components/schemas/User", "User"),
            ("#/components/schemas/ApiKey", "ApiKey"),
            ("User", "User"),
        ];
        for (input, expected) in cases {
            assert_eq!(extract_type_name_from_ref(input), expected, "input: {:?}", input);
        }
    }

    #[test]
    fn test_is_rust_keyword() {
        for keyword in ["type", "async", "await"] {
            assert!(is_rust_keyword(keyword), "expected keyword: {:?}", keyword);
        }
        for ident in ["user", "field"] {
            assert!(!is_rust_keyword(ident), "expected non-keyword: {:?}", ident);
        }
    }
}
314}