1use heck::ToUpperCamelCase;
2
3use super::RustType;
4use crate::introspect::SchemaInfo;
5
6pub fn is_builtin(udt_name: &str) -> bool {
9 matches!(
10 udt_name,
11 "bool"
12 | "int2" | "smallint" | "smallserial"
13 | "int4" | "int" | "integer" | "serial"
14 | "int8" | "bigint" | "bigserial"
15 | "float4" | "real"
16 | "float8" | "double precision"
17 | "numeric" | "decimal"
18 | "varchar" | "text" | "bpchar" | "char" | "name" | "citext"
19 | "bytea"
20 | "timestamp" | "timestamp without time zone"
21 | "timestamptz" | "timestamp with time zone"
22 | "date"
23 | "time" | "time without time zone"
24 | "timetz" | "time with time zone"
25 | "uuid"
26 | "json" | "jsonb"
27 | "inet" | "cidr"
28 | "oid"
29 )
30}
31
32pub fn map_type(udt_name: &str, schema_info: &SchemaInfo) -> RustType {
33 if let Some(inner) = udt_name.strip_prefix('_') {
35 let inner_type = map_type(inner, schema_info);
36 return inner_type.wrap_vec();
37 }
38
39 if schema_info.enums.iter().any(|e| e.name == udt_name) {
41 let name = udt_name.to_upper_camel_case();
42 return RustType::with_import(&name, &format!("use super::types::{};", name));
43 }
44
45 if schema_info.composite_types.iter().any(|c| c.name == udt_name) {
47 let name = udt_name.to_upper_camel_case();
48 return RustType::with_import(&name, &format!("use super::types::{};", name));
49 }
50
51 if let Some(domain) = schema_info.domains.iter().find(|d| d.name == udt_name) {
53 return map_type(&domain.base_type, schema_info);
55 }
56
57 match udt_name {
58 "bool" => RustType::simple("bool"),
59 "int2" | "smallint" | "smallserial" => RustType::simple("i16"),
60 "int4" | "int" | "integer" | "serial" => RustType::simple("i32"),
61 "int8" | "bigint" | "bigserial" => RustType::simple("i64"),
62 "float4" | "real" => RustType::simple("f32"),
63 "float8" | "double precision" => RustType::simple("f64"),
64 "numeric" | "decimal" => {
65 RustType::with_import("Decimal", "use rust_decimal::Decimal;")
66 }
67 "varchar" | "text" | "bpchar" | "char" | "name" | "citext" => RustType::simple("String"),
68 "bytea" => RustType::simple("Vec<u8>"),
69 "timestamp" | "timestamp without time zone" => {
70 RustType::with_import("NaiveDateTime", "use chrono::NaiveDateTime;")
71 }
72 "timestamptz" | "timestamp with time zone" => {
73 RustType::with_import("DateTime<Utc>", "use chrono::{DateTime, Utc};")
74 }
75 "date" => RustType::with_import("NaiveDate", "use chrono::NaiveDate;"),
76 "time" | "time without time zone" => {
77 RustType::with_import("NaiveTime", "use chrono::NaiveTime;")
78 }
79 "timetz" | "time with time zone" => {
80 RustType::with_import("NaiveTime", "use chrono::NaiveTime;")
81 }
82 "uuid" => RustType::with_import("Uuid", "use uuid::Uuid;"),
83 "json" | "jsonb" => {
84 RustType::with_import("Value", "use serde_json::Value;")
85 }
86 "inet" | "cidr" => {
87 RustType::with_import("IpNetwork", "use ipnetwork::IpNetwork;")
88 }
89 "oid" => RustType::simple("u32"),
90 _ => RustType::simple("String"), }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::introspect::{CompositeTypeInfo, DomainInfo, EnumInfo};

    fn empty_schema() -> SchemaInfo {
        SchemaInfo::default()
    }

    fn schema_with_enum(name: &str) -> SchemaInfo {
        SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: name.to_string(),
                variants: vec!["a".to_string()],
                default_variant: None,
            }],
            ..Default::default()
        }
    }

    fn schema_with_composite(name: &str) -> SchemaInfo {
        SchemaInfo {
            composite_types: vec![CompositeTypeInfo {
                schema_name: "public".to_string(),
                name: name.to_string(),
                fields: vec![],
            }],
            ..Default::default()
        }
    }

    fn schema_with_domain(name: &str, base: &str) -> SchemaInfo {
        SchemaInfo {
            domains: vec![DomainInfo {
                schema_name: "public".to_string(),
                name: name.to_string(),
                base_type: base.to_string(),
            }],
            ..Default::default()
        }
    }

    /// Built-in scalar names paired with the Rust type path each must map to.
    const SCALARS: &[(&str, &str)] = &[
        ("bool", "bool"),
        ("int2", "i16"),
        ("smallint", "i16"),
        ("smallserial", "i16"),
        ("int4", "i32"),
        ("int", "i32"),
        ("integer", "i32"),
        ("serial", "i32"),
        ("int8", "i64"),
        ("bigint", "i64"),
        ("bigserial", "i64"),
        ("float4", "f32"),
        ("real", "f32"),
        ("float8", "f64"),
        ("double precision", "f64"),
        ("numeric", "Decimal"),
        ("decimal", "Decimal"),
        ("varchar", "String"),
        ("text", "String"),
        ("bpchar", "String"),
        ("char", "String"),
        ("name", "String"),
        ("citext", "String"),
        ("bytea", "Vec<u8>"),
        ("timestamp", "NaiveDateTime"),
        ("timestamp without time zone", "NaiveDateTime"),
        ("timestamptz", "DateTime<Utc>"),
        ("timestamp with time zone", "DateTime<Utc>"),
        ("date", "NaiveDate"),
        ("time", "NaiveTime"),
        ("time without time zone", "NaiveTime"),
        ("timetz", "NaiveTime"),
        ("time with time zone", "NaiveTime"),
        ("uuid", "Uuid"),
        ("json", "Value"),
        ("jsonb", "Value"),
        ("inet", "IpNetwork"),
        ("cidr", "IpNetwork"),
        ("oid", "u32"),
    ];

    // ---- built-in scalars ----

    #[test]
    fn test_builtin_scalar_paths() {
        for &(udt, expected) in SCALARS {
            assert_eq!(
                map_type(udt, &empty_schema()).path,
                expected,
                "udt_name {:?}",
                udt
            );
        }
    }

    #[test]
    fn test_builtin_imports() {
        let cases: &[(&str, &str)] = &[
            ("numeric", "rust_decimal"),
            ("uuid", "uuid::Uuid"),
            ("json", "serde_json"),
            ("timestamp", "chrono"),
            ("timestamptz", "chrono"),
            ("inet", "ipnetwork"),
        ];
        for &(udt, fragment) in cases {
            let rt = map_type(udt, &empty_schema());
            let import = rt
                .needs_import
                .as_ref()
                .expect("built-in type with an external crate should need an import");
            assert!(
                import.contains(fragment),
                "import for {:?} should mention {:?}",
                udt,
                fragment
            );
        }
    }

    // ---- is_builtin ----

    #[test]
    fn test_mapped_scalars_are_builtin() {
        // Guards against map_type and is_builtin disagreeing on the type list.
        for &(udt, _) in SCALARS {
            assert!(is_builtin(udt), "{:?} should be builtin", udt);
        }
    }

    #[test]
    fn test_non_builtins_are_rejected() {
        for &udt in &["geometry", "hstore", "_int4", "status", ""] {
            assert!(!is_builtin(udt), "{:?} should not be builtin", udt);
        }
    }

    // ---- arrays ----

    #[test]
    fn test_arrays_wrap_in_vec() {
        let cases: &[(&str, &str)] = &[
            ("_int4", "Vec<i32>"),
            ("_text", "Vec<String>"),
            ("_uuid", "Vec<Uuid>"),
            ("_bool", "Vec<bool>"),
            ("_jsonb", "Vec<Value>"),
            ("_bytea", "Vec<Vec<u8>>"),
        ];
        for &(udt, expected) in cases {
            assert_eq!(
                map_type(udt, &empty_schema()).path,
                expected,
                "udt_name {:?}",
                udt
            );
        }
    }

    #[test]
    fn test_arrays_keep_inner_import() {
        assert!(map_type("_uuid", &empty_schema()).needs_import.is_some());
        assert!(map_type("_jsonb", &empty_schema()).needs_import.is_some());
    }

    // ---- enums ----

    #[test]
    fn test_enum_status() {
        let schema = schema_with_enum("status");
        let rt = map_type("status", &schema);
        assert_eq!(rt.path, "Status");
        assert!(rt.needs_import.as_ref().unwrap().contains("super::types::Status"));
    }

    #[test]
    fn test_enum_user_role() {
        let schema = schema_with_enum("user_role");
        assert_eq!(map_type("user_role", &schema).path, "UserRole");
    }

    // ---- composites ----

    #[test]
    fn test_composite_address() {
        let schema = schema_with_composite("address");
        let rt = map_type("address", &schema);
        assert_eq!(rt.path, "Address");
        assert!(rt.needs_import.as_ref().unwrap().contains("super::types::Address"));
    }

    #[test]
    fn test_composite_geo_point() {
        let schema = schema_with_composite("geo_point");
        assert_eq!(map_type("geo_point", &schema).path, "GeoPoint");
    }

    // ---- domains ----

    #[test]
    fn test_domain_text() {
        let schema = schema_with_domain("email", "text");
        assert_eq!(map_type("email", &schema).path, "String");
    }

    #[test]
    fn test_domain_int4() {
        let schema = schema_with_domain("positive_int", "int4");
        assert_eq!(map_type("positive_int", &schema).path, "i32");
    }

    #[test]
    fn test_domain_uuid() {
        let schema = schema_with_domain("my_uuid", "uuid");
        let rt = map_type("my_uuid", &schema);
        assert_eq!(rt.path, "Uuid");
        assert!(rt.needs_import.is_some());
    }

    // ---- arrays of user-defined types ----

    #[test]
    fn test_array_enum() {
        let schema = schema_with_enum("status");
        let rt = map_type("_status", &schema);
        assert_eq!(rt.path, "Vec<Status>");
        assert!(rt.needs_import.is_some());
    }

    #[test]
    fn test_array_composite() {
        let schema = schema_with_composite("address");
        assert_eq!(map_type("_address", &schema).path, "Vec<Address>");
    }

    // ---- fallback ----

    #[test]
    fn test_unknown_types_fall_back_to_string() {
        for &udt in &["geometry", "hstore"] {
            assert_eq!(
                map_type(udt, &empty_schema()).path,
                "String",
                "udt_name {:?}",
                udt
            );
        }
    }
}