1use deserializers::PgRowDeserializer;
2use serde::de::Error;
3use serde::de::{value::Error as DeError, Deserialize};
4
5use sqlx::postgres::{PgRow, PgValueRef};
6
7pub fn from_pg_row<T>(row: PgRow) -> Result<T, DeError>
9where
10 T: for<'de> Deserialize<'de>,
11{
12 let deserializer = PgRowDeserializer::new(&row);
13 T::deserialize(deserializer)
14}
15
16fn decode_raw_pg<'a, T>(raw_value: PgValueRef<'a>) -> Result<T, DeError>
17where
18 T: sqlx::Decode<'a, sqlx::Postgres>,
19{
20 T::decode(raw_value).map_err(|err| {
21 DeError::custom(format!(
22 "Failed to decode {} value: {:?}",
23 std::any::type_name::<T>(),
24 err,
25 ))
26 })
27}
28
29mod seq_access {
30 use std::fmt::Debug;
31
32 use serde::de::{value::Error as DeError, DeserializeSeed, SeqAccess, Visitor};
33 use serde::ser::Error as _;
34 use serde::{de, forward_to_deserialize_any};
35 use sqlx::{postgres::PgValueRef, Row};
36
37 use crate::{
38 decode_raw_pg,
39 deserializers::{PgRowDeserializer, PgValueDeserializer},
40 };
41
42 pub(crate) struct PgRowSeqAccess<'a> {
44 pub(crate) deserializer: PgRowDeserializer<'a>,
45 pub(crate) num_cols: usize,
46 }
47
48 impl<'de, 'a> SeqAccess<'de> for PgRowSeqAccess<'a> {
49 type Error = DeError;
50
51 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
52 where
53 T: DeserializeSeed<'de>,
54 {
55 if self.deserializer.index < self.num_cols {
56 let value = self
57 .deserializer
58 .row
59 .try_get_raw(self.deserializer.index)
60 .map_err(DeError::custom)?;
61
62 let pg_value_deserializer = PgValueDeserializer { value };
64
65 self.deserializer.index += 1;
66
67 seed.deserialize(pg_value_deserializer).map(Some)
69 } else {
70 Ok(None)
71 }
72 }
73 }
74
75 use serde::de::IntoDeserializer;
76
77 pub struct PgArraySeqAccess<T> {
81 iter: std::vec::IntoIter<Option<T>>,
82 }
83
84 impl<'de, 'a, T> PgArraySeqAccess<T>
85 where
86 T: sqlx::Decode<'a, sqlx::Postgres> + Debug,
87 {
88 pub fn new(value: PgValueRef<'a>) -> Result<Self, DeError>
89 where
90 Vec<Option<T>>: sqlx::Decode<'a, sqlx::Postgres> + Debug,
91 {
92 let vec: Vec<Option<T>> = decode_raw_pg(value)?;
93
94 Ok(PgArraySeqAccess {
95 iter: vec.into_iter(),
96 })
97 }
98 }
99
100 impl<'de, T> SeqAccess<'de> for PgArraySeqAccess<T>
101 where
102 T: IntoDeserializer<'de, DeError>,
103 {
104 type Error = DeError;
105
106 fn next_element_seed<U>(&mut self, seed: U) -> Result<Option<U::Value>, Self::Error>
107 where
108 U: DeserializeSeed<'de>,
109 {
110 let Some(value) = self.iter.next() else {
111 return Ok(None);
112 };
113
114 seed.deserialize(PgArrayElementDeserializer { value })
115 .map(Some)
116 }
117 }
118
119 struct PgArrayElementDeserializer<T> {
121 pub value: Option<T>,
122 }
123
124 impl<'de, T> de::Deserializer<'de> for PgArrayElementDeserializer<T>
125 where
126 T: IntoDeserializer<'de, DeError>,
127 {
128 type Error = DeError;
129
130 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
131 where
132 V: Visitor<'de>,
133 {
134 match self.value {
135 Some(v) => visitor.visit_some(v.into_deserializer()),
136 None => visitor.visit_none(),
137 }
138 }
139
140 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
141 where
142 V: Visitor<'de>,
143 {
144 match self.value {
145 Some(v) => v.into_deserializer().deserialize_any(visitor),
146 None => Err(DeError::custom(
147 "unexpected null in non-optional array element",
148 )),
149 }
150 }
151
152 forward_to_deserialize_any! {
153 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
154 bytes byte_buf unit unit_struct newtype_struct seq tuple tuple_struct
155 map struct enum identifier ignored_any
156 }
157 }
158}
159
160mod map_access {
161 use serde::de::{self, value::Error as DeError, IntoDeserializer, MapAccess};
162 use serde::ser::Error as _;
163
164 use sqlx::{Column, Row};
165
166 use crate::deserializers::{PgRowDeserializer, PgValueDeserializer};
167
168 pub(crate) struct PgRowMapAccess<'a> {
169 pub(crate) deserializer: PgRowDeserializer<'a>,
170 pub(crate) num_cols: usize,
171 }
172
173 impl<'de, 'a> MapAccess<'de> for PgRowMapAccess<'a> {
174 type Error = DeError;
175
176 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
177 where
178 K: de::DeserializeSeed<'de>,
179 {
180 if self.deserializer.index < self.num_cols {
181 let col_name = self.deserializer.row.columns()[self.deserializer.index].name();
182 seed.deserialize(col_name.into_deserializer()).map(Some)
184 } else {
185 Ok(None)
186 }
187 }
188
189 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
190 where
191 V: de::DeserializeSeed<'de>,
192 {
193 let value = self
194 .deserializer
195 .row
196 .try_get_raw(self.deserializer.index)
197 .map_err(DeError::custom)?;
198 let pg_type_deserializer = PgValueDeserializer { value };
199
200 self.deserializer.index += 1;
201
202 seed.deserialize(pg_type_deserializer)
203 }
204 }
205}
206
207mod deserializers {
208 use crate::decode_raw_pg;
209 use crate::json::PgJson;
210 use crate::map_access::PgRowMapAccess;
211 use crate::seq_access::{PgArraySeqAccess, PgRowSeqAccess};
212 use serde::de::{value::Error as DeError, Deserializer, Visitor};
213 use serde::de::{Error as _, IntoDeserializer};
214 use serde::forward_to_deserialize_any;
215 use sqlx::postgres::{PgRow, PgValueRef};
216 use sqlx::{Row, TypeInfo, ValueRef};
217
218 #[derive(Clone, Copy)]
219 pub struct PgRowDeserializer<'a> {
220 pub(crate) row: &'a PgRow,
221 pub(crate) index: usize,
222 }
223
224 impl<'a> PgRowDeserializer<'a> {
225 pub fn new(row: &'a PgRow) -> Self {
226 PgRowDeserializer { row, index: 0 }
227 }
228
229 #[allow(unused)]
230 pub fn is_json(&self) -> bool {
231 self.row.try_get_raw(0).map_or(false, |value| {
232 matches!(value.type_info().name(), "JSON" | "JSONB")
233 })
234 }
235 }
236
237 impl<'de, 'a> Deserializer<'de> for PgRowDeserializer<'a> {
238 type Error = DeError;
239
240 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
241 where
242 V: Visitor<'de>,
243 {
244 let raw_value = self.row.try_get_raw(0).map_err(DeError::custom)?;
245
246 if raw_value.is_null() {
247 visitor.visit_none()
248 } else {
249 visitor.visit_some(self)
250 }
251 }
252
253 fn deserialize_newtype_struct<V>(
254 self,
255 _name: &'static str,
256 visitor: V,
257 ) -> Result<V::Value, Self::Error>
258 where
259 V: Visitor<'de>,
260 {
261 visitor.visit_newtype_struct(self)
262 }
263
264 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
265 where
266 V: Visitor<'de>,
267 {
268 match self.row.columns().len() {
269 0 => return visitor.visit_unit(),
270 1 => {}
271 _n => {
272 return self.deserialize_seq(visitor);
273 }
274 };
275
276 let raw_value = self.row.try_get_raw(self.index).map_err(DeError::custom)?;
277 let type_info = raw_value.type_info();
278 let type_name = type_info.name();
279
280 if raw_value.is_null() {
281 return visitor.visit_none();
282 }
283
284 if type_name.ends_with("[]") {
286 return self.deserialize_seq(visitor);
287 }
288
289 let deserializer = PgValueDeserializer { value: raw_value };
291
292 deserializer.deserialize_any(visitor)
293 }
294
295 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
297 where
298 V: Visitor<'de>,
299 {
300 visitor.visit_map(PgRowMapAccess {
301 deserializer: self,
302 num_cols: self.row.columns().len(),
303 })
304 }
305
306 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
307 where
308 V: Visitor<'de>,
309 {
310 let raw_value = self.row.try_get_raw(self.index).map_err(DeError::custom)?;
311 let type_info = raw_value.type_info();
312 let type_name = type_info.name();
313
314 match type_name {
315 "TEXT[]" | "VARCHAR[]" => {
316 let seq_access = PgArraySeqAccess::<String>::new(raw_value)?;
317 visitor.visit_seq(seq_access)
318 }
319 "INT4[]" => {
320 let seq_access = PgArraySeqAccess::<i32>::new(raw_value)?;
321 visitor.visit_seq(seq_access)
322 }
323 "JSON[]" | "JSONB[]" => {
324 let seq_access = PgArraySeqAccess::<PgJson>::new(raw_value)?;
325 visitor.visit_seq(seq_access)
326 }
327 "BOOL[]" => {
328 let seq_access = PgArraySeqAccess::<bool>::new(raw_value)?;
329 visitor.visit_seq(seq_access)
330 }
331 _ => {
332 let seq_access = PgRowSeqAccess {
333 deserializer: self,
334 num_cols: self.row.columns().len(),
335 };
336
337 visitor.visit_seq(seq_access)
338 }
339 }
340 }
341
342 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
343 where
344 V: Visitor<'de>,
345 {
346 self.deserialize_seq(visitor)
347 }
348
349 fn deserialize_struct<V>(
350 self,
351 _name: &'static str,
352 fields: &'static [&'static str],
353 visitor: V,
354 ) -> Result<V::Value, Self::Error>
355 where
356 V: Visitor<'de>,
357 {
358 let raw_value = self.row.try_get_raw(self.index).map_err(DeError::custom)?;
359 let type_info = raw_value.type_info();
360 let type_name = type_info.name();
361
362 if type_name == "JSON" || type_name == "JSONB" {
363 let value = decode_raw_pg::<PgJson>(raw_value).map_err(|err| {
364 DeError::custom(format!("Failed to decode JSON/JSONB: {err}"))
365 })?;
366
367 if let serde_json::Value::Object(ref obj) = value.0 {
368 if fields.len() == 1 {
369 if obj.contains_key(fields[0]) {
371 return value.into_deserializer().deserialize_any(visitor);
373 } else {
374 let mut map = serde_json::Map::new();
376 map.insert(fields[0].to_owned(), value.0);
377 return map
378 .into_deserializer()
379 .deserialize_any(visitor)
380 .map_err(DeError::custom);
381 }
382 } else {
383 if fields.iter().all(|&field| obj.contains_key(field)) {
385 return value.into_deserializer().deserialize_any(visitor);
386 } else {
387 return Err(DeError::custom(format!(
388 "JSON object missing expected keys: expected {:?}, found keys {:?}",
389 fields,
390 obj.keys().collect::<Vec<_>>()
391 )));
392 }
393 }
394 } else {
395 return value.into_deserializer().deserialize_any(visitor);
397 }
398 }
399
400 self.deserialize_map(visitor)
402 }
403
404 forward_to_deserialize_any! {
406 bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string
407 bytes byte_buf unit unit_struct
408 tuple_struct enum identifier ignored_any
409 }
410 }
411
412 #[derive(Clone)]
414 pub(crate) struct PgValueDeserializer<'a> {
415 pub(crate) value: PgValueRef<'a>,
416 }
417
418 impl<'de, 'a> Deserializer<'de> for PgValueDeserializer<'a> {
419 type Error = DeError;
420
421 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
422 where
423 V: Visitor<'de>,
424 {
425 if self.value.is_null() {
426 visitor.visit_none()
427 } else {
428 visitor.visit_some(self)
429 }
430 }
431
432 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
433 where
434 V: Visitor<'de>,
435 {
436 if self.value.is_null() {
437 return visitor.visit_none();
438 }
439 let type_info = self.value.type_info();
440
441 let type_name = type_info.name();
442
443 match type_name {
444 "FLOAT4" => {
445 let v = decode_raw_pg::<f32>(self.value)?;
446 visitor.visit_f32(v)
447 }
448 "FLOAT8" => {
449 let v = decode_raw_pg::<f64>(self.value)?;
450 visitor.visit_f64(v)
451 }
452 "NUMERIC" => {
453 let numeric = decode_raw_pg::<rust_decimal::Decimal>(self.value)?;
454
455 let num: f64 = numeric
456 .try_into()
457 .map_err(|_| DeError::custom("Failed to parse Decimal as f64"))?;
458
459 visitor.visit_f64(num)
460 }
461 "INT8" => {
462 let v = decode_raw_pg::<i64>(self.value)?;
463 visitor.visit_i64(v)
464 }
465 "INT4" => {
466 let v = decode_raw_pg::<i32>(self.value)?;
467 visitor.visit_i32(v)
468 }
469 "INT2" => {
470 let v = decode_raw_pg::<i16>(self.value)?;
471 visitor.visit_i16(v)
472 }
473 "BOOL" => {
474 let v = decode_raw_pg::<bool>(self.value)?;
475 visitor.visit_bool(v)
476 }
477 "DATE" => {
478 let date = decode_raw_pg::<chrono::NaiveDate>(self.value)?;
479 visitor.visit_string(date.to_string())
480 }
481 "TIME" | "TIMETZ" => {
482 let time = decode_raw_pg::<chrono::NaiveTime>(self.value)?;
483 visitor.visit_string(time.to_string())
484 }
485 "TIMESTAMP" | "TIMESTAMPTZ" => {
486 let ts = decode_raw_pg::<chrono::DateTime<chrono::FixedOffset>>(self.value)?;
487 visitor.visit_string(ts.to_rfc3339())
488 }
489 "UUID" => {
490 let uuid = decode_raw_pg::<uuid::Uuid>(self.value)?;
491 visitor.visit_string(uuid.to_string())
492 }
493 "BYTEA" => {
494 let bytes = decode_raw_pg::<&[u8]>(self.value)?;
495 visitor.visit_bytes(bytes)
496 }
497 "INTERVAL" => {
498 let pg_interval =
499 decode_raw_pg::<sqlx::postgres::types::PgInterval>(self.value)?;
500 let secs = pg_interval.microseconds / 1_000_000;
501 let nanos = (pg_interval.microseconds % 1_000_000) * 1000;
502 let days_duration = chrono::Duration::days(pg_interval.days as i64);
503 let duration = chrono::Duration::seconds(secs)
504 + chrono::Duration::nanoseconds(nanos)
505 + days_duration;
506 visitor.visit_string(duration.to_string())
507 }
508 "CHAR" | "TEXT" => {
509 let s = decode_raw_pg::<String>(self.value)?;
510 visitor.visit_string(s)
511 }
512 "JSON" | "JSONB" => {
513 let value = decode_raw_pg::<PgJson>(self.value)?;
514
515 value.into_deserializer().deserialize_any(visitor)
516 }
517 _other => {
518 let as_string = decode_raw_pg::<String>(self.value.clone())?;
519 visitor.visit_string(as_string)
520 }
521 }
522 }
523
524 forward_to_deserialize_any! {
526 bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string
527 bytes byte_buf unit unit_struct newtype_struct struct
528 tuple_struct enum identifier ignored_any tuple seq map
529 }
530 }
531}
532
533mod json {
534 use serde::{
535 de::{self, value::Error as DeError, Deserializer, Error, IntoDeserializer},
536 forward_to_deserialize_any,
537 };
538 use serde_json::Value;
539 use sqlx::{
540 postgres::{PgTypeInfo, PgValueRef},
541 Postgres, TypeInfo, ValueRef,
542 };
543
544 #[derive(Debug)]
546 pub(crate) struct PgJson(pub(crate) serde_json::Value);
547
548 impl<'a> sqlx::Decode<'a, sqlx::Postgres> for PgJson {
549 fn decode(value: PgValueRef<'a>) -> Result<Self, sqlx::error::BoxDynError> {
550 let is_jsonb = match value.type_info().name() {
551 "JSON" => false,
552 "JSONB" => true,
553 other => unreachable!("Got {other} in PgJson"),
554 };
555
556 let mut bytes = value.as_bytes()?;
557
558 if is_jsonb {
560 if bytes.is_empty() || bytes[0] != 1 {
561 return Err("invalid JSONB header".into());
562 }
563
564 bytes = &bytes[1..]
566 };
567
568 let value = serde_json::from_slice(bytes)?;
569
570 Ok(PgJson(value))
571 }
572 }
573
574 impl sqlx::Type<Postgres> for PgJson {
575 fn type_info() -> PgTypeInfo {
576 PgTypeInfo::with_name("JSON")
577 }
578 }
579
580 pub struct PgJsonDeserializer {
581 value: Value,
582 }
583
584 impl<'de> Deserializer<'de> for PgJsonDeserializer {
585 type Error = DeError;
586
587 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
588 where
589 V: de::Visitor<'de>,
590 {
591 self.value.deserialize_any(visitor).map_err(DeError::custom)
593 }
594
595 forward_to_deserialize_any! {
596 bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string
597 bytes byte_buf option unit unit_struct newtype_struct seq tuple
598 tuple_struct map struct enum identifier ignored_any
599 }
600 }
601
602 impl<'de> IntoDeserializer<'de> for PgJson {
603 type Deserializer = PgJsonDeserializer;
604
605 fn into_deserializer(self) -> Self::Deserializer {
606 PgJsonDeserializer { value: self.0 }
607 }
608 }
609}