1#![allow(clippy::result_large_err)]
9#![allow(clippy::cast_possible_truncation)]
11
12use sqlmodel_core::Error;
13use sqlmodel_core::error::TypeError;
14use sqlmodel_core::value::Value;
15
16use super::encode::Format;
17use super::oid;
18
19pub trait TextDecode: Sized {
21 fn decode_text(s: &str) -> Result<Self, Error>;
23}
24
25pub trait BinaryDecode: Sized {
27 fn decode_binary(data: &[u8]) -> Result<Self, Error>;
29}
30
31pub trait Decode: TextDecode + BinaryDecode {
33 fn decode(data: &[u8], format: Format) -> Result<Self, Error> {
35 match format {
36 Format::Text => {
37 let s = std::str::from_utf8(data).map_err(|_| {
38 Error::Type(TypeError {
39 expected: "valid UTF-8",
40 actual: format!("invalid bytes: {:?}", &data[..data.len().min(20)]),
41 column: None,
42 rust_type: None,
43 })
44 })?;
45 Self::decode_text(s)
46 }
47 Format::Binary => Self::decode_binary(data),
48 }
49 }
50}
51
52impl TextDecode for bool {
55 fn decode_text(s: &str) -> Result<Self, Error> {
56 match s {
57 "t" | "true" | "TRUE" | "1" | "y" | "yes" | "on" => Ok(true),
58 "f" | "false" | "FALSE" | "0" | "n" | "no" | "off" => Ok(false),
59 _ => Err(type_error("bool", s)),
60 }
61 }
62}
63
64impl BinaryDecode for bool {
65 fn decode_binary(data: &[u8]) -> Result<Self, Error> {
66 if data.len() != 1 {
67 return Err(binary_length_error("bool", 1, data.len()));
68 }
69 Ok(data[0] != 0)
70 }
71}
72
73impl Decode for bool {}
74
75impl TextDecode for i8 {
78 fn decode_text(s: &str) -> Result<Self, Error> {
79 s.parse().map_err(|_| type_error("i8", s))
80 }
81}
82
83impl BinaryDecode for i8 {
84 fn decode_binary(data: &[u8]) -> Result<Self, Error> {
85 if data.len() != 1 {
86 return Err(binary_length_error("i8", 1, data.len()));
87 }
88 Ok(data[0] as i8)
89 }
90}
91
92impl Decode for i8 {}
93
94impl TextDecode for i16 {
95 fn decode_text(s: &str) -> Result<Self, Error> {
96 s.parse().map_err(|_| type_error("int2", s))
97 }
98}
99
100impl BinaryDecode for i16 {
101 fn decode_binary(data: &[u8]) -> Result<Self, Error> {
102 if data.len() != 2 {
103 return Err(binary_length_error("int2", 2, data.len()));
104 }
105 Ok(i16::from_be_bytes([data[0], data[1]]))
106 }
107}
108
109impl Decode for i16 {}
110
111impl TextDecode for i32 {
112 fn decode_text(s: &str) -> Result<Self, Error> {
113 s.parse().map_err(|_| type_error("int4", s))
114 }
115}
116
117impl BinaryDecode for i32 {
118 fn decode_binary(data: &[u8]) -> Result<Self, Error> {
119 if data.len() != 4 {
120 return Err(binary_length_error("int4", 4, data.len()));
121 }
122 Ok(i32::from_be_bytes([data[0], data[1], data[2], data[3]]))
123 }
124}
125
126impl Decode for i32 {}
127
128impl TextDecode for i64 {
129 fn decode_text(s: &str) -> Result<Self, Error> {
130 s.parse().map_err(|_| type_error("int8", s))
131 }
132}
133
134impl BinaryDecode for i64 {
135 fn decode_binary(data: &[u8]) -> Result<Self, Error> {
136 if data.len() != 8 {
137 return Err(binary_length_error("int8", 8, data.len()));
138 }
139 Ok(i64::from_be_bytes([
140 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
141 ]))
142 }
143}
144
145impl Decode for i64 {}
146
147impl TextDecode for u32 {
150 fn decode_text(s: &str) -> Result<Self, Error> {
151 s.parse().map_err(|_| type_error("oid", s))
153 }
154}
155
156impl BinaryDecode for u32 {
157 fn decode_binary(data: &[u8]) -> Result<Self, Error> {
158 if data.len() != 4 {
159 return Err(binary_length_error("oid", 4, data.len()));
160 }
161 Ok(u32::from_be_bytes([data[0], data[1], data[2], data[3]]))
162 }
163}
164
165impl Decode for u32 {}
166
167impl TextDecode for f32 {
170 fn decode_text(s: &str) -> Result<Self, Error> {
171 match s {
172 "NaN" => Ok(f32::NAN),
173 "Infinity" => Ok(f32::INFINITY),
174 "-Infinity" => Ok(f32::NEG_INFINITY),
175 _ => s.parse().map_err(|_| type_error("float4", s)),
176 }
177 }
178}
179
180impl BinaryDecode for f32 {
181 fn decode_binary(data: &[u8]) -> Result<Self, Error> {
182 if data.len() != 4 {
183 return Err(binary_length_error("float4", 4, data.len()));
184 }
185 Ok(f32::from_be_bytes([data[0], data[1], data[2], data[3]]))
186 }
187}
188
189impl Decode for f32 {}
190
191impl TextDecode for f64 {
192 fn decode_text(s: &str) -> Result<Self, Error> {
193 match s {
194 "NaN" => Ok(f64::NAN),
195 "Infinity" => Ok(f64::INFINITY),
196 "-Infinity" => Ok(f64::NEG_INFINITY),
197 _ => s.parse().map_err(|_| type_error("float8", s)),
198 }
199 }
200}
201
202impl BinaryDecode for f64 {
203 fn decode_binary(data: &[u8]) -> Result<Self, Error> {
204 if data.len() != 8 {
205 return Err(binary_length_error("float8", 8, data.len()));
206 }
207 Ok(f64::from_be_bytes([
208 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
209 ]))
210 }
211}
212
213impl Decode for f64 {}
214
215impl TextDecode for String {
218 fn decode_text(s: &str) -> Result<Self, Error> {
219 Ok(s.to_string())
220 }
221}
222
223impl BinaryDecode for String {
224 fn decode_binary(data: &[u8]) -> Result<Self, Error> {
225 String::from_utf8(data.to_vec()).map_err(|_| {
226 Error::Type(TypeError {
227 expected: "valid UTF-8",
228 actual: format!("invalid bytes: {:?}", &data[..data.len().min(20)]),
229 column: None,
230 rust_type: None,
231 })
232 })
233 }
234}
235
236impl Decode for String {}
237
238impl TextDecode for Vec<u8> {
241 fn decode_text(s: &str) -> Result<Self, Error> {
242 if let Some(hex) = s.strip_prefix("\\x") {
244 decode_hex(hex)
245 } else {
246 decode_bytea_escape(s)
248 }
249 }
250}
251
252impl BinaryDecode for Vec<u8> {
253 fn decode_binary(data: &[u8]) -> Result<Self, Error> {
254 Ok(data.to_vec())
255 }
256}
257
258impl Decode for Vec<u8> {}
259
260impl TextDecode for [u8; 16] {
263 fn decode_text(s: &str) -> Result<Self, Error> {
264 let s = s.replace('-', "");
266 if s.len() != 32 {
267 return Err(type_error("uuid", s));
268 }
269
270 let mut bytes = [0u8; 16];
271 for (i, byte) in bytes.iter_mut().enumerate() {
272 *byte =
273 u8::from_str_radix(&s[i * 2..i * 2 + 2], 16).map_err(|_| type_error("uuid", &s))?;
274 }
275 Ok(bytes)
276 }
277}
278
279impl BinaryDecode for [u8; 16] {
280 fn decode_binary(data: &[u8]) -> Result<Self, Error> {
281 if data.len() != 16 {
282 return Err(binary_length_error("uuid", 16, data.len()));
283 }
284 let mut bytes = [0u8; 16];
285 bytes.copy_from_slice(data);
286 Ok(bytes)
287 }
288}
289
290impl Decode for [u8; 16] {}
291
/// Days from the Unix epoch (1970-01-01) to the PostgreSQL epoch (2000-01-01).
const PG_EPOCH_OFFSET_DAYS: i32 = 10_957;

/// Microseconds from the Unix epoch (1970-01-01) to the PostgreSQL epoch (2000-01-01).
const PG_EPOCH_OFFSET_MICROS: i64 = 946_684_800_000_000;

/// Converts a binary `date` (days since 2000-01-01) to days since 1970-01-01.
///
/// PostgreSQL encodes `infinity` / `-infinity` dates as `i32::MAX` /
/// `i32::MIN`. Those sentinels are passed through unchanged instead of
/// overflowing. Any other value that would overflow saturates.
pub fn decode_date_days(pg_days: i32) -> i32 {
    match pg_days {
        i32::MIN | i32::MAX => pg_days,
        _ => pg_days.saturating_add(PG_EPOCH_OFFSET_DAYS),
    }
}

/// Converts a binary timestamp (microseconds since 2000-01-01) to
/// microseconds since 1970-01-01.
///
/// PostgreSQL encodes `infinity` / `-infinity` timestamps as `i64::MAX` /
/// `i64::MIN`, and those sentinels are passed through unchanged. Finite values
/// near PostgreSQL's upper timestamp bound do not fit after the epoch shift,
/// so they saturate to `i64::MAX` rather than overflowing.
pub fn decode_timestamp_micros(pg_micros: i64) -> i64 {
    match pg_micros {
        i64::MIN | i64::MAX => pg_micros,
        _ => pg_micros.saturating_add(PG_EPOCH_OFFSET_MICROS),
    }
}
313
314pub fn parse_date_string(s: &str) -> Result<i32, Error> {
318 let parts: Vec<&str> = s.split('-').collect();
319 if parts.len() != 3 {
320 return Err(type_error("date", s));
321 }
322
323 let year: i32 = parts[0].parse().map_err(|_| type_error("date", s))?;
324 let month: u32 = parts[1].parse().map_err(|_| type_error("date", s))?;
325 let day: u32 = parts[2].parse().map_err(|_| type_error("date", s))?;
326
327 Ok(date_to_days(year, month, day))
329}
330
331pub fn parse_time_string(s: &str) -> Result<i64, Error> {
335 let (time_part, micros_part) = if let Some(pos) = s.find('.') {
336 (&s[..pos], Some(&s[pos + 1..]))
337 } else {
338 (s, None)
339 };
340
341 let parts: Vec<&str> = time_part.split(':').collect();
342 if parts.len() < 2 || parts.len() > 3 {
343 return Err(type_error("time", s));
344 }
345
346 let hours: i64 = parts[0].parse().map_err(|_| type_error("time", s))?;
347 let mins: i64 = parts[1].parse().map_err(|_| type_error("time", s))?;
348 let secs: i64 = if parts.len() == 3 {
349 parts[2].parse().map_err(|_| type_error("time", s))?
350 } else {
351 0
352 };
353
354 let mut micros = (hours * 3600 + mins * 60 + secs) * 1_000_000;
355
356 if let Some(frac) = micros_part {
357 let frac_str = if frac.len() > 6 { &frac[..6] } else { frac };
359 let frac_micros: i64 = frac_str.parse().map_err(|_| type_error("time", s))?;
360 let multiplier = 10_i64.pow(6 - frac_str.len() as u32);
361 micros += frac_micros * multiplier;
362 }
363
364 Ok(micros)
365}
366
367pub fn parse_timestamp_string(s: &str) -> Result<i64, Error> {
371 let s = s.replace('T', " ");
373
374 let s = if let Some(pos) = s.find('+') {
376 &s[..pos]
377 } else if let Some(pos) = s.rfind('-') {
378 if pos > 10 { &s[..pos] } else { &s }
380 } else {
381 &s
382 };
383
384 let parts: Vec<&str> = s.split(' ').collect();
385 if parts.len() != 2 {
386 if parts.len() == 1 {
388 let days = parse_date_string(parts[0])?;
389 return Ok(i64::from(days) * 86_400 * 1_000_000);
390 }
391 return Err(type_error("timestamp", s));
392 }
393
394 let days = parse_date_string(parts[0])?;
395 let time_micros = parse_time_string(parts[1])?;
396
397 Ok(i64::from(days) * 86_400 * 1_000_000 + time_micros)
398}
399
400pub fn decode_value(type_oid: u32, data: Option<&[u8]>, format: Format) -> Result<Value, Error> {
409 let Some(data) = data else {
410 return Ok(Value::Null);
411 };
412
413 match (type_oid, format) {
414 (oid::BOOL, Format::Binary) => Ok(Value::Bool(bool::decode_binary(data)?)),
416 (oid::BOOL, Format::Text) => {
417 let s = std::str::from_utf8(data).map_err(utf8_error)?;
418 Ok(Value::Bool(bool::decode_text(s)?))
419 }
420
421 (oid::INT2, Format::Binary) => Ok(Value::SmallInt(i16::decode_binary(data)?)),
423 (oid::INT2, Format::Text) => {
424 let s = std::str::from_utf8(data).map_err(utf8_error)?;
425 Ok(Value::SmallInt(i16::decode_text(s)?))
426 }
427
428 (oid::INT4, Format::Binary) => Ok(Value::Int(i32::decode_binary(data)?)),
429 (oid::INT4, Format::Text) => {
430 let s = std::str::from_utf8(data).map_err(utf8_error)?;
431 Ok(Value::Int(i32::decode_text(s)?))
432 }
433
434 (oid::INT8, Format::Binary) => Ok(Value::BigInt(i64::decode_binary(data)?)),
435 (oid::INT8, Format::Text) => {
436 let s = std::str::from_utf8(data).map_err(utf8_error)?;
437 Ok(Value::BigInt(i64::decode_text(s)?))
438 }
439
440 (oid::FLOAT4, Format::Binary) => Ok(Value::Float(f32::decode_binary(data)?)),
442 (oid::FLOAT4, Format::Text) => {
443 let s = std::str::from_utf8(data).map_err(utf8_error)?;
444 Ok(Value::Float(f32::decode_text(s)?))
445 }
446
447 (oid::FLOAT8, Format::Binary) => Ok(Value::Double(f64::decode_binary(data)?)),
448 (oid::FLOAT8, Format::Text) => {
449 let s = std::str::from_utf8(data).map_err(utf8_error)?;
450 Ok(Value::Double(f64::decode_text(s)?))
451 }
452
453 (oid::NUMERIC, _) => {
455 let s = std::str::from_utf8(data).map_err(utf8_error)?;
456 Ok(Value::Decimal(s.to_string()))
457 }
458
459 (oid::TEXT | oid::VARCHAR | oid::BPCHAR | oid::NAME | oid::CHAR, _) => {
461 Ok(Value::Text(String::decode_binary(data)?))
462 }
463
464 (oid::BYTEA, Format::Binary) => Ok(Value::Bytes(data.to_vec())),
466 (oid::BYTEA, Format::Text) => {
467 let s = std::str::from_utf8(data).map_err(utf8_error)?;
468 Ok(Value::Bytes(Vec::<u8>::decode_text(s)?))
469 }
470
471 (oid::DATE, Format::Binary) => {
473 let pg_days = i32::decode_binary(data)?;
474 Ok(Value::Date(decode_date_days(pg_days)))
475 }
476 (oid::DATE, Format::Text) => {
477 let s = std::str::from_utf8(data).map_err(utf8_error)?;
478 Ok(Value::Date(parse_date_string(s)?))
479 }
480
481 (oid::TIME | oid::TIMETZ, Format::Binary) => {
483 let micros = i64::decode_binary(data)?;
484 Ok(Value::Time(micros))
485 }
486 (oid::TIME | oid::TIMETZ, Format::Text) => {
487 let s = std::str::from_utf8(data).map_err(utf8_error)?;
488 Ok(Value::Time(parse_time_string(s)?))
489 }
490
491 (oid::TIMESTAMP, Format::Binary) => {
493 let pg_micros = i64::decode_binary(data)?;
494 Ok(Value::Timestamp(decode_timestamp_micros(pg_micros)))
495 }
496 (oid::TIMESTAMP, Format::Text) => {
497 let s = std::str::from_utf8(data).map_err(utf8_error)?;
498 Ok(Value::Timestamp(parse_timestamp_string(s)?))
499 }
500
501 (oid::TIMESTAMPTZ, Format::Binary) => {
503 let pg_micros = i64::decode_binary(data)?;
504 Ok(Value::TimestampTz(decode_timestamp_micros(pg_micros)))
505 }
506 (oid::TIMESTAMPTZ, Format::Text) => {
507 let s = std::str::from_utf8(data).map_err(utf8_error)?;
508 Ok(Value::TimestampTz(parse_timestamp_string(s)?))
509 }
510
511 (oid::UUID, Format::Binary) => {
513 let bytes = <[u8; 16]>::decode_binary(data)?;
514 Ok(Value::Uuid(bytes))
515 }
516 (oid::UUID, Format::Text) => {
517 let s = std::str::from_utf8(data).map_err(utf8_error)?;
518 let bytes = <[u8; 16]>::decode_text(s)?;
519 Ok(Value::Uuid(bytes))
520 }
521
522 (oid::JSON, _) => {
524 let s = std::str::from_utf8(data).map_err(utf8_error)?;
525 let json: serde_json::Value =
526 serde_json::from_str(s).map_err(|e| type_error_with_source("json", s, e))?;
527 Ok(Value::Json(json))
528 }
529
530 (oid::JSONB, Format::Binary) => {
532 if data.is_empty() {
533 return Err(type_error("jsonb", "empty data"));
534 }
535 let json_data = &data[1..];
537 let s = std::str::from_utf8(json_data).map_err(utf8_error)?;
538 let json: serde_json::Value =
539 serde_json::from_str(s).map_err(|e| type_error_with_source("jsonb", s, e))?;
540 Ok(Value::Json(json))
541 }
542 (oid::JSONB, Format::Text) => {
543 let s = std::str::from_utf8(data).map_err(utf8_error)?;
544 let json: serde_json::Value =
545 serde_json::from_str(s).map_err(|e| type_error_with_source("jsonb", s, e))?;
546 Ok(Value::Json(json))
547 }
548
549 (oid::OID | oid::XID | oid::CID, Format::Binary) => {
551 let v = u32::decode_binary(data)?;
552 Ok(Value::Int(v as i32))
553 }
554 (oid::OID | oid::XID | oid::CID, Format::Text) => {
555 let s = std::str::from_utf8(data).map_err(utf8_error)?;
556 let v = u32::decode_text(s)?;
557 Ok(Value::Int(v as i32))
558 }
559
560 (_, _) => Ok(Value::Text(String::decode_binary(data)?)),
562 }
563}
564
565fn type_error(expected: &'static str, value: impl std::fmt::Display) -> Error {
568 Error::Type(TypeError {
569 expected,
570 actual: format!("invalid value: {}", value),
571 column: None,
572 rust_type: None,
573 })
574}
575
576fn type_error_with_source<E: std::error::Error>(
577 expected: &'static str,
578 value: impl std::fmt::Display,
579 source: E,
580) -> Error {
581 Error::Type(TypeError {
582 expected,
583 actual: format!("invalid value: {} ({})", value, source),
584 column: None,
585 rust_type: None,
586 })
587}
588
589fn binary_length_error(type_name: &'static str, expected: usize, actual: usize) -> Error {
590 Error::Type(TypeError {
591 expected: type_name,
592 actual: format!("expected {} bytes, got {}", expected, actual),
593 column: None,
594 rust_type: None,
595 })
596}
597
598fn utf8_error(_e: std::str::Utf8Error) -> Error {
599 Error::Type(TypeError {
600 expected: "valid UTF-8",
601 actual: "invalid UTF-8 bytes".to_string(),
602 column: None,
603 rust_type: None,
604 })
605}
606
607fn decode_hex(s: &str) -> Result<Vec<u8>, Error> {
609 let s = s.trim();
610 if s.len() % 2 != 0 {
611 return Err(type_error("bytea hex", s));
612 }
613
614 let mut bytes = Vec::with_capacity(s.len() / 2);
615 for i in (0..s.len()).step_by(2) {
616 let byte = u8::from_str_radix(&s[i..i + 2], 16).map_err(|_| type_error("bytea hex", s))?;
617 bytes.push(byte);
618 }
619 Ok(bytes)
620}
621
622fn decode_bytea_escape(s: &str) -> Result<Vec<u8>, Error> {
624 let mut bytes = Vec::with_capacity(s.len());
625 let mut chars = s.chars().peekable();
626
627 while let Some(c) = chars.next() {
628 if c == '\\' {
629 match chars.peek() {
630 Some('\\') => {
631 chars.next();
632 bytes.push(b'\\');
633 }
634 Some(c) if c.is_ascii_digit() => {
635 let mut octal = String::with_capacity(3);
637 for _ in 0..3 {
638 if let Some(&c) = chars.peek() {
639 if c.is_ascii_digit() {
640 octal.push(c);
641 chars.next();
642 } else {
643 break;
644 }
645 }
646 }
647 let byte =
648 u8::from_str_radix(&octal, 8).map_err(|_| type_error("bytea escape", s))?;
649 bytes.push(byte);
650 }
651 _ => {
652 bytes.push(b'\\');
654 }
655 }
656 } else {
657 bytes.push(c as u8);
658 }
659 }
660
661 Ok(bytes)
662}
663
/// Converts a proleptic-Gregorian civil date to days since 1970-01-01.
///
/// Uses the era-based "days from civil" algorithm. The year is treated as
/// starting in March, so the leap day falls at the end of the shifted year.
/// `month` is expected in `1..=12` and `day` to be at least 1; the arithmetic
/// is unchecked.
fn date_to_days(year: i32, month: u32, day: u32) -> i32 {
    let (shifted_year, month_from_march) = if month > 2 {
        (year, month - 3)
    } else {
        (year - 1, month + 9)
    };

    // 400-year Gregorian cycle containing `shifted_year` (floor division).
    let era = match shifted_year {
        y if y >= 0 => y / 400,
        y => (y - 399) / 400,
    };
    let year_of_era = (shifted_year - era * 400) as u32;
    let day_of_year = (153 * month_from_march + 2) / 5 + day - 1;
    let day_of_era = year_of_era * 365 + year_of_era / 4 - year_of_era / 100 + day_of_year;

    // 719_468 = days from 0000-03-01 to 1970-01-01.
    era * 146_097 + day_of_era as i32 - 719_468
}
674
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bool_decoding() {
        for truthy in ["t", "true"] {
            assert!(bool::decode_text(truthy).unwrap());
        }
        for falsy in ["f", "false"] {
            assert!(!bool::decode_text(falsy).unwrap());
        }

        assert!(bool::decode_binary(&[1]).unwrap());
        assert!(!bool::decode_binary(&[0]).unwrap());
    }

    #[test]
    fn test_integer_decoding() {
        for (input, expected) in [("42", 42), ("-100", -100)] {
            assert_eq!(i32::decode_text(input).unwrap(), expected);
        }

        let binary_cases: [([u8; 4], i32); 2] = [([0, 0, 0, 42], 42), ([0, 0, 1, 0], 256)];
        for (input, expected) in binary_cases {
            assert_eq!(i32::decode_binary(&input).unwrap(), expected);
        }
    }

    #[test]
    fn test_float_decoding() {
        assert!(f64::decode_text("NaN").unwrap().is_nan());
        for infinite in ["Infinity", "-Infinity"] {
            assert!(f64::decode_text(infinite).unwrap().is_infinite());
        }

        let decoded = f64::decode_text("1.5").unwrap();
        assert!((decoded - 1.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_bytea_hex_decoding() {
        let decoded = Vec::<u8>::decode_text("\\xdeadbeef").unwrap();
        assert_eq!(decoded, vec![0xDE, 0xAD, 0xBE, 0xEF]);
    }

    #[test]
    fn test_uuid_decoding() {
        let expected: [u8; 16] = [
            0x55, 0x06, 0x9c, 0x47, 0x86, 0x8b, 0x4a, 0x08, 0xa4, 0x7f, 0x36, 0x53, 0x26, 0x2b,
            0xce, 0x35,
        ];
        let decoded = <[u8; 16]>::decode_text("55069c47-868b-4a08-a47f-3653262bce35").unwrap();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn test_date_parsing() {
        // PostgreSQL epoch, then Unix epoch.
        assert_eq!(parse_date_string("2000-01-01").unwrap(), 10_957);
        assert_eq!(parse_date_string("1970-01-01").unwrap(), 0);
    }

    #[test]
    fn test_time_parsing() {
        let cases = [
            ("00:00:00", 0),
            ("01:00:00", 3_600_000_000),
            ("12:30:45.123456", 45_045_123_456),
        ];
        for (input, expected) in cases {
            assert_eq!(parse_time_string(input).unwrap(), expected);
        }
    }

    #[test]
    fn test_decode_value_null() {
        let decoded = decode_value(oid::INT4, None, Format::Binary).unwrap();
        assert!(matches!(decoded, Value::Null));
    }

    #[test]
    fn test_decode_value_int() {
        let decoded = decode_value(oid::INT4, Some(&[0, 0, 0, 42]), Format::Binary).unwrap();
        assert!(matches!(decoded, Value::Int(42)));
    }
}
757}