1use heck::ToUpperCamelCase;
2
3use super::RustType;
4use crate::introspect::SchemaInfo;
5
6pub fn map_type(udt_name: &str, schema_info: &SchemaInfo) -> RustType {
7 if let Some(inner) = udt_name.strip_prefix('_') {
9 let inner_type = map_type(inner, schema_info);
10 return inner_type.wrap_vec();
11 }
12
13 if schema_info.enums.iter().any(|e| e.name == udt_name) {
15 let name = udt_name.to_upper_camel_case();
16 return RustType::with_import(&name, &format!("use super::types::{};", name));
17 }
18
19 if schema_info.composite_types.iter().any(|c| c.name == udt_name) {
21 let name = udt_name.to_upper_camel_case();
22 return RustType::with_import(&name, &format!("use super::types::{};", name));
23 }
24
25 if let Some(domain) = schema_info.domains.iter().find(|d| d.name == udt_name) {
27 return map_type(&domain.base_type, schema_info);
29 }
30
31 match udt_name {
32 "bool" => RustType::simple("bool"),
33 "int2" | "smallint" | "smallserial" => RustType::simple("i16"),
34 "int4" | "int" | "integer" | "serial" => RustType::simple("i32"),
35 "int8" | "bigint" | "bigserial" => RustType::simple("i64"),
36 "float4" | "real" => RustType::simple("f32"),
37 "float8" | "double precision" => RustType::simple("f64"),
38 "numeric" | "decimal" => {
39 RustType::with_import("Decimal", "use rust_decimal::Decimal;")
40 }
41 "varchar" | "text" | "bpchar" | "char" | "name" | "citext" => RustType::simple("String"),
42 "bytea" => RustType::simple("Vec<u8>"),
43 "timestamp" | "timestamp without time zone" => {
44 RustType::with_import("NaiveDateTime", "use chrono::NaiveDateTime;")
45 }
46 "timestamptz" | "timestamp with time zone" => {
47 RustType::with_import("DateTime<Utc>", "use chrono::{DateTime, Utc};")
48 }
49 "date" => RustType::with_import("NaiveDate", "use chrono::NaiveDate;"),
50 "time" | "time without time zone" => {
51 RustType::with_import("NaiveTime", "use chrono::NaiveTime;")
52 }
53 "timetz" | "time with time zone" => {
54 RustType::with_import("NaiveTime", "use chrono::NaiveTime;")
55 }
56 "uuid" => RustType::with_import("Uuid", "use uuid::Uuid;"),
57 "json" | "jsonb" => {
58 RustType::with_import("Value", "use serde_json::Value;")
59 }
60 "inet" | "cidr" => {
61 RustType::with_import("IpNetwork", "use ipnetwork::IpNetwork;")
62 }
63 "oid" => RustType::simple("u32"),
64 _ => RustType::simple("String"), }
66}
67
#[cfg(test)]
mod tests {
    use super::*;
    use crate::introspect::{CompositeTypeInfo, DomainInfo, EnumInfo};

    // --- fixtures ---

    /// Schema with no user-defined types, so only arrays and the built-in table can match.
    fn empty_schema() -> SchemaInfo {
        SchemaInfo::default()
    }

    /// Schema containing a single enum called `name`.
    fn schema_with_enum(name: &str) -> SchemaInfo {
        SchemaInfo {
            enums: vec![EnumInfo {
                schema_name: "public".to_string(),
                name: name.to_string(),
                variants: vec!["a".to_string()],
            }],
            ..Default::default()
        }
    }

    /// Schema containing a single, field-less composite type called `name`.
    fn schema_with_composite(name: &str) -> SchemaInfo {
        SchemaInfo {
            composite_types: vec![CompositeTypeInfo {
                schema_name: "public".to_string(),
                name: name.to_string(),
                fields: vec![],
            }],
            ..Default::default()
        }
    }

    /// Schema containing a single domain `name` defined over `base`.
    fn schema_with_domain(name: &str, base: &str) -> SchemaInfo {
        SchemaInfo {
            domains: vec![DomainInfo {
                schema_name: "public".to_string(),
                name: name.to_string(),
                base_type: base.to_string(),
            }],
            ..Default::default()
        }
    }

    // --- assertion helpers ---

    /// Asserts that `udt` maps to the Rust path `expected` under an empty schema.
    #[track_caller]
    fn assert_builtin(udt: &str, expected: &str) {
        assert_eq!(map_type(udt, &empty_schema()).path, expected);
    }

    /// Asserts that `rt` carries an import line containing `needle`.
    #[track_caller]
    fn assert_import_contains(rt: &RustType, needle: &str) {
        assert!(rt.needs_import.as_ref().unwrap().contains(needle));
    }

    // --- built-in scalar types ---

    #[test]
    fn test_bool() {
        assert_builtin("bool", "bool");
    }

    #[test]
    fn test_int2() {
        assert_builtin("int2", "i16");
    }

    #[test]
    fn test_smallint() {
        assert_builtin("smallint", "i16");
    }

    #[test]
    fn test_smallserial() {
        assert_builtin("smallserial", "i16");
    }

    #[test]
    fn test_int4() {
        assert_builtin("int4", "i32");
    }

    #[test]
    fn test_int() {
        assert_builtin("int", "i32");
    }

    #[test]
    fn test_integer() {
        assert_builtin("integer", "i32");
    }

    #[test]
    fn test_serial() {
        assert_builtin("serial", "i32");
    }

    #[test]
    fn test_int8() {
        assert_builtin("int8", "i64");
    }

    #[test]
    fn test_bigint() {
        assert_builtin("bigint", "i64");
    }

    #[test]
    fn test_bigserial() {
        assert_builtin("bigserial", "i64");
    }

    #[test]
    fn test_float4() {
        assert_builtin("float4", "f32");
    }

    #[test]
    fn test_real() {
        assert_builtin("real", "f32");
    }

    #[test]
    fn test_float8() {
        assert_builtin("float8", "f64");
    }

    #[test]
    fn test_double_precision() {
        assert_builtin("double precision", "f64");
    }

    #[test]
    fn test_numeric() {
        let rt = map_type("numeric", &empty_schema());
        assert_eq!(rt.path, "Decimal");
        assert_import_contains(&rt, "rust_decimal");
    }

    #[test]
    fn test_decimal() {
        assert_builtin("decimal", "Decimal");
    }

    #[test]
    fn test_varchar() {
        assert_builtin("varchar", "String");
    }

    #[test]
    fn test_text() {
        assert_builtin("text", "String");
    }

    #[test]
    fn test_bpchar() {
        assert_builtin("bpchar", "String");
    }

    #[test]
    fn test_char() {
        assert_builtin("char", "String");
    }

    #[test]
    fn test_citext() {
        assert_builtin("citext", "String");
    }

    #[test]
    fn test_name() {
        assert_builtin("name", "String");
    }

    #[test]
    fn test_bytea() {
        assert_builtin("bytea", "Vec<u8>");
    }

    #[test]
    fn test_uuid() {
        let rt = map_type("uuid", &empty_schema());
        assert_eq!(rt.path, "Uuid");
        assert_import_contains(&rt, "uuid::Uuid");
    }

    #[test]
    fn test_json() {
        let rt = map_type("json", &empty_schema());
        assert_eq!(rt.path, "Value");
        assert_import_contains(&rt, "serde_json");
    }

    #[test]
    fn test_jsonb() {
        assert_builtin("jsonb", "Value");
    }

    #[test]
    fn test_timestamp() {
        let rt = map_type("timestamp", &empty_schema());
        assert_eq!(rt.path, "NaiveDateTime");
        assert_import_contains(&rt, "chrono");
    }

    #[test]
    fn test_timestamp_without_time_zone() {
        assert_builtin("timestamp without time zone", "NaiveDateTime");
    }

    #[test]
    fn test_timestamptz() {
        let rt = map_type("timestamptz", &empty_schema());
        assert_eq!(rt.path, "DateTime<Utc>");
        assert_import_contains(&rt, "chrono");
    }

    #[test]
    fn test_timestamp_with_time_zone() {
        assert_builtin("timestamp with time zone", "DateTime<Utc>");
    }

    #[test]
    fn test_date() {
        assert_builtin("date", "NaiveDate");
    }

    #[test]
    fn test_time() {
        assert_builtin("time", "NaiveTime");
    }

    #[test]
    fn test_time_without_time_zone() {
        assert_builtin("time without time zone", "NaiveTime");
    }

    #[test]
    fn test_timetz() {
        assert_builtin("timetz", "NaiveTime");
    }

    #[test]
    fn test_time_with_time_zone() {
        assert_builtin("time with time zone", "NaiveTime");
    }

    #[test]
    fn test_inet() {
        let rt = map_type("inet", &empty_schema());
        assert_eq!(rt.path, "IpNetwork");
        assert_import_contains(&rt, "ipnetwork");
    }

    #[test]
    fn test_cidr() {
        assert_builtin("cidr", "IpNetwork");
    }

    #[test]
    fn test_oid() {
        assert_builtin("oid", "u32");
    }

    // --- arrays of built-in types (leading underscore) ---

    #[test]
    fn test_array_int4() {
        assert_builtin("_int4", "Vec<i32>");
    }

    #[test]
    fn test_array_text() {
        assert_builtin("_text", "Vec<String>");
    }

    #[test]
    fn test_array_uuid() {
        let rt = map_type("_uuid", &empty_schema());
        assert_eq!(rt.path, "Vec<Uuid>");
        assert!(rt.needs_import.is_some());
    }

    #[test]
    fn test_array_bool() {
        assert_builtin("_bool", "Vec<bool>");
    }

    #[test]
    fn test_array_jsonb() {
        let rt = map_type("_jsonb", &empty_schema());
        assert_eq!(rt.path, "Vec<Value>");
        assert!(rt.needs_import.is_some());
    }

    #[test]
    fn test_array_bytea() {
        assert_builtin("_bytea", "Vec<Vec<u8>>");
    }

    // --- user-defined types: enums, composites, domains ---

    #[test]
    fn test_enum_status() {
        let schema = schema_with_enum("status");
        let rt = map_type("status", &schema);
        assert_eq!(rt.path, "Status");
        assert_import_contains(&rt, "super::types::Status");
    }

    #[test]
    fn test_enum_user_role() {
        let schema = schema_with_enum("user_role");
        assert_eq!(map_type("user_role", &schema).path, "UserRole");
    }

    #[test]
    fn test_composite_address() {
        let schema = schema_with_composite("address");
        let rt = map_type("address", &schema);
        assert_eq!(rt.path, "Address");
        assert_import_contains(&rt, "super::types::Address");
    }

    #[test]
    fn test_composite_geo_point() {
        let schema = schema_with_composite("geo_point");
        assert_eq!(map_type("geo_point", &schema).path, "GeoPoint");
    }

    #[test]
    fn test_domain_text() {
        let schema = schema_with_domain("email", "text");
        assert_eq!(map_type("email", &schema).path, "String");
    }

    #[test]
    fn test_domain_int4() {
        let schema = schema_with_domain("positive_int", "int4");
        assert_eq!(map_type("positive_int", &schema).path, "i32");
    }

    #[test]
    fn test_domain_uuid() {
        let schema = schema_with_domain("my_uuid", "uuid");
        let rt = map_type("my_uuid", &schema);
        assert_eq!(rt.path, "Uuid");
        assert!(rt.needs_import.is_some());
    }

    // --- arrays and domains composed with user-defined types ---

    #[test]
    fn test_array_enum() {
        let schema = schema_with_enum("status");
        let rt = map_type("_status", &schema);
        assert_eq!(rt.path, "Vec<Status>");
        assert!(rt.needs_import.is_some());
    }

    #[test]
    fn test_array_composite() {
        let schema = schema_with_composite("address");
        assert_eq!(map_type("_address", &schema).path, "Vec<Address>");
    }

    #[test]
    fn test_array_domain() {
        // An array of a domain resolves through the domain to its base type.
        let schema = schema_with_domain("email", "text");
        assert_eq!(map_type("_email", &schema).path, "Vec<String>");
    }

    #[test]
    fn test_domain_array_base() {
        // A domain whose base type is itself an array type.
        let schema = schema_with_domain("int_list", "_int4");
        assert_eq!(map_type("int_list", &schema).path, "Vec<i32>");
    }

    // --- unknown types fall back to String ---

    #[test]
    fn test_geometry_fallback() {
        assert_builtin("geometry", "String");
    }

    #[test]
    fn test_hstore_fallback() {
        assert_builtin("hstore", "String");
    }
}