serde_sqlx/
lib.rs

1use deserializers::PgRowDeserializer;
2use serde::de::Error;
3use serde::de::{value::Error as DeError, Deserialize};
4
5use sqlx::postgres::{PgRow, PgValueRef};
6
7/// Convenience function: deserialize a PgRow into any T that implements Deserialize
8pub fn from_pg_row<T>(row: PgRow) -> Result<T, DeError>
9where
10    T: for<'de> Deserialize<'de>,
11{
12    let deserializer = PgRowDeserializer::new(&row);
13    T::deserialize(deserializer)
14}
15
16fn decode_raw_pg<'a, T>(raw_value: PgValueRef<'a>) -> Result<T, DeError>
17where
18    T: sqlx::Decode<'a, sqlx::Postgres>,
19{
20    T::decode(raw_value).map_err(|err| {
21        DeError::custom(format!(
22            "Failed to decode {} value: {:?}",
23            std::any::type_name::<T>(),
24            err,
25        ))
26    })
27}
28
29mod seq_access {
30    use std::fmt::Debug;
31
32    use serde::de::{value::Error as DeError, DeserializeSeed, SeqAccess, Visitor};
33    use serde::ser::Error as _;
34    use serde::{de, forward_to_deserialize_any};
35    use sqlx::{postgres::PgValueRef, Row};
36
37    use crate::{
38        decode_raw_pg,
39        deserializers::{PgRowDeserializer, PgValueDeserializer},
40    };
41
42    /// A SeqAccess implementation that iterates over the row’s columns
43    pub(crate) struct PgRowSeqAccess<'a> {
44        pub(crate) deserializer: PgRowDeserializer<'a>,
45        pub(crate) num_cols: usize,
46    }
47
48    impl<'de, 'a> SeqAccess<'de> for PgRowSeqAccess<'a> {
49        type Error = DeError;
50
51        fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
52        where
53            T: DeserializeSeed<'de>,
54        {
55            if self.deserializer.index < self.num_cols {
56                let value = self
57                    .deserializer
58                    .row
59                    .try_get_raw(self.deserializer.index)
60                    .map_err(DeError::custom)?;
61
62                // Create a PgValueDeserializer for the current column.
63                let pg_value_deserializer = PgValueDeserializer { value };
64
65                self.deserializer.index += 1;
66
67                // Deserialize the value and return it wrapped in Some.
68                seed.deserialize(pg_value_deserializer).map(Some)
69            } else {
70                Ok(None)
71            }
72        }
73    }
74
75    use serde::de::IntoDeserializer;
76
77    /// SeqAccess implementation for Postgres arrays
78    /// It decodes a raw Postgres array, such as TEXT[] into a `Vec<Option<T>>` and
79    /// then yields each element during deserialization
80    pub struct PgArraySeqAccess<T> {
81        iter: std::vec::IntoIter<Option<T>>,
82    }
83
84    impl<'de, 'a, T> PgArraySeqAccess<T>
85    where
86        T: sqlx::Decode<'a, sqlx::Postgres> + Debug,
87    {
88        pub fn new(value: PgValueRef<'a>) -> Result<Self, DeError>
89        where
90            Vec<Option<T>>: sqlx::Decode<'a, sqlx::Postgres> + Debug,
91        {
92            let vec: Vec<Option<T>> = decode_raw_pg(value)?;
93
94            Ok(PgArraySeqAccess {
95                iter: vec.into_iter(),
96            })
97        }
98    }
99
100    impl<'de, T> SeqAccess<'de> for PgArraySeqAccess<T>
101    where
102        T: IntoDeserializer<'de, DeError>,
103    {
104        type Error = DeError;
105
106        fn next_element_seed<U>(&mut self, seed: U) -> Result<Option<U::Value>, Self::Error>
107        where
108            U: DeserializeSeed<'de>,
109        {
110            let Some(value) = self.iter.next() else {
111                return Ok(None);
112            };
113
114            seed.deserialize(PgArrayElementDeserializer { value })
115                .map(Some)
116        }
117    }
118
119    /// Yet another deserializer, this time to handles Options
120    struct PgArrayElementDeserializer<T> {
121        pub value: Option<T>,
122    }
123
124    impl<'de, T> de::Deserializer<'de> for PgArrayElementDeserializer<T>
125    where
126        T: IntoDeserializer<'de, DeError>,
127    {
128        type Error = DeError;
129
130        fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
131        where
132            V: Visitor<'de>,
133        {
134            match self.value {
135                Some(v) => visitor.visit_some(v.into_deserializer()),
136                None => visitor.visit_none(),
137            }
138        }
139
140        fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
141        where
142            V: Visitor<'de>,
143        {
144            match self.value {
145                Some(v) => v.into_deserializer().deserialize_any(visitor),
146                None => Err(DeError::custom(
147                    "unexpected null in non-optional array element",
148                )),
149            }
150        }
151
152        forward_to_deserialize_any! {
153            bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
154            bytes byte_buf unit unit_struct newtype_struct seq tuple tuple_struct
155            map struct enum identifier ignored_any
156        }
157    }
158}
159
160mod map_access {
161    use serde::de::{self, value::Error as DeError, IntoDeserializer, MapAccess};
162    use serde::ser::Error as _;
163
164    use sqlx::{Column, Row};
165
166    use crate::deserializers::{PgRowDeserializer, PgValueDeserializer};
167
168    pub(crate) struct PgRowMapAccess<'a> {
169        pub(crate) deserializer: PgRowDeserializer<'a>,
170        pub(crate) num_cols: usize,
171    }
172
173    impl<'de, 'a> MapAccess<'de> for PgRowMapAccess<'a> {
174        type Error = DeError;
175
176        fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
177        where
178            K: de::DeserializeSeed<'de>,
179        {
180            if self.deserializer.index < self.num_cols {
181                let col_name = self.deserializer.row.columns()[self.deserializer.index].name();
182                // Use the column name as the key
183                seed.deserialize(col_name.into_deserializer()).map(Some)
184            } else {
185                Ok(None)
186            }
187        }
188
189        fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
190        where
191            V: de::DeserializeSeed<'de>,
192        {
193            let value = self
194                .deserializer
195                .row
196                .try_get_raw(self.deserializer.index)
197                .map_err(DeError::custom)?;
198            let pg_type_deserializer = PgValueDeserializer { value };
199
200            self.deserializer.index += 1;
201
202            seed.deserialize(pg_type_deserializer)
203        }
204    }
205}
206
207mod deserializers {
208    use crate::decode_raw_pg;
209    use crate::json::PgJson;
210    use crate::map_access::PgRowMapAccess;
211    use crate::seq_access::{PgArraySeqAccess, PgRowSeqAccess};
212    use serde::de::{value::Error as DeError, Deserializer, Visitor};
213    use serde::de::{Error as _, IntoDeserializer};
214    use serde::forward_to_deserialize_any;
215    use sqlx::postgres::{PgRow, PgValueRef};
216    use sqlx::{Row, TypeInfo, ValueRef};
217
218    #[derive(Clone, Copy)]
219    pub struct PgRowDeserializer<'a> {
220        pub(crate) row: &'a PgRow,
221        pub(crate) index: usize,
222    }
223
224    impl<'a> PgRowDeserializer<'a> {
225        pub fn new(row: &'a PgRow) -> Self {
226            PgRowDeserializer { row, index: 0 }
227        }
228
229        #[allow(unused)]
230        pub fn is_json(&self) -> bool {
231            self.row.try_get_raw(0).map_or(false, |value| {
232                matches!(value.type_info().name(), "JSON" | "JSONB")
233            })
234        }
235    }
236
237    impl<'de, 'a> Deserializer<'de> for PgRowDeserializer<'a> {
238        type Error = DeError;
239
240        fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
241        where
242            V: Visitor<'de>,
243        {
244            let raw_value = self.row.try_get_raw(0).map_err(DeError::custom)?;
245
246            if raw_value.is_null() {
247                visitor.visit_none()
248            } else {
249                visitor.visit_some(self)
250            }
251        }
252
253        fn deserialize_newtype_struct<V>(
254            self,
255            _name: &'static str,
256            visitor: V,
257        ) -> Result<V::Value, Self::Error>
258        where
259            V: Visitor<'de>,
260        {
261            visitor.visit_newtype_struct(self)
262        }
263
264        fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
265        where
266            V: Visitor<'de>,
267        {
268            match self.row.columns().len() {
269                0 => return visitor.visit_unit(),
270                1 => {}
271                _n => {
272                    return self.deserialize_seq(visitor);
273                }
274            };
275
276            let raw_value = self.row.try_get_raw(self.index).map_err(DeError::custom)?;
277            let type_info = raw_value.type_info();
278            let type_name = type_info.name();
279
280            if raw_value.is_null() {
281                return visitor.visit_none();
282            }
283
284            // If this is a BOOL[], TEXT[], etc
285            if type_name.ends_with("[]") {
286                return self.deserialize_seq(visitor);
287            }
288
289            // Direct all "basic" types down to `PgValueDeserializer`
290            let deserializer = PgValueDeserializer { value: raw_value };
291
292            deserializer.deserialize_any(visitor)
293        }
294
295        /// We treat the row as a map (each column is a key/value pair)
296        fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
297        where
298            V: Visitor<'de>,
299        {
300            visitor.visit_map(PgRowMapAccess {
301                deserializer: self,
302                num_cols: self.row.columns().len(),
303            })
304        }
305
306        fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
307        where
308            V: Visitor<'de>,
309        {
310            let raw_value = self.row.try_get_raw(self.index).map_err(DeError::custom)?;
311            let type_info = raw_value.type_info();
312            let type_name = type_info.name();
313
314            match type_name {
315                "TEXT[]" | "VARCHAR[]" => {
316                    let seq_access = PgArraySeqAccess::<String>::new(raw_value)?;
317                    visitor.visit_seq(seq_access)
318                }
319                "INT4[]" => {
320                    let seq_access = PgArraySeqAccess::<i32>::new(raw_value)?;
321                    visitor.visit_seq(seq_access)
322                }
323                "JSON[]" | "JSONB[]" => {
324                    let seq_access = PgArraySeqAccess::<PgJson>::new(raw_value)?;
325                    visitor.visit_seq(seq_access)
326                }
327                "BOOL[]" => {
328                    let seq_access = PgArraySeqAccess::<bool>::new(raw_value)?;
329                    visitor.visit_seq(seq_access)
330                }
331                _ => {
332                    let seq_access = PgRowSeqAccess {
333                        deserializer: self,
334                        num_cols: self.row.columns().len(),
335                    };
336
337                    visitor.visit_seq(seq_access)
338                }
339            }
340        }
341
342        fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
343        where
344            V: Visitor<'de>,
345        {
346            self.deserialize_seq(visitor)
347        }
348
349        fn deserialize_struct<V>(
350            self,
351            _name: &'static str,
352            fields: &'static [&'static str],
353            visitor: V,
354        ) -> Result<V::Value, Self::Error>
355        where
356            V: Visitor<'de>,
357        {
358            let raw_value = self.row.try_get_raw(self.index).map_err(DeError::custom)?;
359            let type_info = raw_value.type_info();
360            let type_name = type_info.name();
361
362            if type_name == "JSON" || type_name == "JSONB" {
363                let value = decode_raw_pg::<PgJson>(raw_value).map_err(|err| {
364                    DeError::custom(format!("Failed to decode JSON/JSONB: {err}"))
365                })?;
366
367                if let serde_json::Value::Object(ref obj) = value.0 {
368                    if fields.len() == 1 {
369                        // If there's only one expected field, check if the object already contains it.
370                        if obj.contains_key(fields[0]) {
371                            // If so, we can deserialize directly.
372                            return value.into_deserializer().deserialize_any(visitor);
373                        } else {
374                            // Otherwise, wrap the object in a new map keyed by that field name.
375                            let mut map = serde_json::Map::new();
376                            map.insert(fields[0].to_owned(), value.0);
377                            return map
378                                .into_deserializer()
379                                .deserialize_any(visitor)
380                                .map_err(DeError::custom);
381                        }
382                    } else {
383                        // For multiple expected fields, ensure the JSON object already contains all of them.
384                        if fields.iter().all(|&field| obj.contains_key(field)) {
385                            return value.into_deserializer().deserialize_any(visitor);
386                        } else {
387                            return Err(DeError::custom(format!(
388                                "JSON object missing expected keys: expected {:?}, found keys {:?}",
389                                fields,
390                                obj.keys().collect::<Vec<_>>()
391                            )));
392                        }
393                    }
394                } else {
395                    // For non-object JSON values, delegate directly.
396                    return value.into_deserializer().deserialize_any(visitor);
397                }
398            }
399
400            // Fallback for non-JSON types.
401            self.deserialize_map(visitor)
402        }
403
404        // For other types, forward to deserialize_any.
405        forward_to_deserialize_any! {
406            bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string
407            bytes byte_buf unit unit_struct
408            tuple_struct enum identifier ignored_any
409        }
410    }
411
412    /// An "inner" deserializer
413    #[derive(Clone)]
414    pub(crate) struct PgValueDeserializer<'a> {
415        pub(crate) value: PgValueRef<'a>,
416    }
417
418    impl<'de, 'a> Deserializer<'de> for PgValueDeserializer<'a> {
419        type Error = DeError;
420
421        fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
422        where
423            V: Visitor<'de>,
424        {
425            if self.value.is_null() {
426                visitor.visit_none()
427            } else {
428                visitor.visit_some(self)
429            }
430        }
431
432        fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
433        where
434            V: Visitor<'de>,
435        {
436            if self.value.is_null() {
437                return visitor.visit_none();
438            }
439            let type_info = self.value.type_info();
440
441            let type_name = type_info.name();
442
443            match type_name {
444                "FLOAT4" => {
445                    let v = decode_raw_pg::<f32>(self.value)?;
446                    visitor.visit_f32(v)
447                }
448                "FLOAT8" => {
449                    let v = decode_raw_pg::<f64>(self.value)?;
450                    visitor.visit_f64(v)
451                }
452                "NUMERIC" => {
453                    let numeric = decode_raw_pg::<rust_decimal::Decimal>(self.value)?;
454
455                    let num: f64 = numeric
456                        .try_into()
457                        .map_err(|_| DeError::custom("Failed to parse Decimal as f64"))?;
458
459                    visitor.visit_f64(num)
460                }
461                "INT8" => {
462                    let v = decode_raw_pg::<i64>(self.value)?;
463                    visitor.visit_i64(v)
464                }
465                "INT4" => {
466                    let v = decode_raw_pg::<i32>(self.value)?;
467                    visitor.visit_i32(v)
468                }
469                "INT2" => {
470                    let v = decode_raw_pg::<i16>(self.value)?;
471                    visitor.visit_i16(v)
472                }
473                "BOOL" => {
474                    let v = decode_raw_pg::<bool>(self.value)?;
475                    visitor.visit_bool(v)
476                }
477                "DATE" => {
478                    let date = decode_raw_pg::<chrono::NaiveDate>(self.value)?;
479                    visitor.visit_string(date.to_string())
480                }
481                "TIME" | "TIMETZ" => {
482                    let time = decode_raw_pg::<chrono::NaiveTime>(self.value)?;
483                    visitor.visit_string(time.to_string())
484                }
485                "TIMESTAMP" | "TIMESTAMPTZ" => {
486                    let ts = decode_raw_pg::<chrono::DateTime<chrono::FixedOffset>>(self.value)?;
487                    visitor.visit_string(ts.to_rfc3339())
488                }
489                "UUID" => {
490                    let uuid = decode_raw_pg::<uuid::Uuid>(self.value)?;
491                    visitor.visit_string(uuid.to_string())
492                }
493                "BYTEA" => {
494                    let bytes = decode_raw_pg::<&[u8]>(self.value)?;
495                    visitor.visit_bytes(bytes)
496                }
497                "INTERVAL" => {
498                    let pg_interval =
499                        decode_raw_pg::<sqlx::postgres::types::PgInterval>(self.value)?;
500                    let secs = pg_interval.microseconds / 1_000_000;
501                    let nanos = (pg_interval.microseconds % 1_000_000) * 1000;
502                    let days_duration = chrono::Duration::days(pg_interval.days as i64);
503                    let duration = chrono::Duration::seconds(secs)
504                        + chrono::Duration::nanoseconds(nanos)
505                        + days_duration;
506                    visitor.visit_string(duration.to_string())
507                }
508                "CHAR" | "TEXT" => {
509                    let s = decode_raw_pg::<String>(self.value)?;
510                    visitor.visit_string(s)
511                }
512                "JSON" | "JSONB" => {
513                    let value = decode_raw_pg::<PgJson>(self.value)?;
514
515                    value.into_deserializer().deserialize_any(visitor)
516                }
517                _other => {
518                    let as_string = decode_raw_pg::<String>(self.value.clone())?;
519                    visitor.visit_string(as_string)
520                }
521            }
522        }
523
524        // For other types, forward to deserialize_any.
525        forward_to_deserialize_any! {
526            bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string
527            bytes byte_buf unit unit_struct newtype_struct struct
528            tuple_struct enum identifier ignored_any tuple seq map
529        }
530    }
531}
532
533mod json {
534    use serde::{
535        de::{self, value::Error as DeError, Deserializer, Error, IntoDeserializer},
536        forward_to_deserialize_any,
537    };
538    use serde_json::Value;
539    use sqlx::{
540        postgres::{PgTypeInfo, PgValueRef},
541        Postgres, TypeInfo, ValueRef,
542    };
543
544    /// Decodes Postgres' JSON or JSONB into serde_json::Value
545    #[derive(Debug)]
546    pub(crate) struct PgJson(pub(crate) serde_json::Value);
547
548    impl<'a> sqlx::Decode<'a, sqlx::Postgres> for PgJson {
549        fn decode(value: PgValueRef<'a>) -> Result<Self, sqlx::error::BoxDynError> {
550            let is_jsonb = match value.type_info().name() {
551                "JSON" => false,
552                "JSONB" => true,
553                other => unreachable!("Got {other} in PgJson"),
554            };
555
556            let mut bytes = value.as_bytes()?;
557
558            // For JSONB, the first byte is a version (should be 1)
559            if is_jsonb {
560                if bytes.is_empty() || bytes[0] != 1 {
561                    return Err("invalid JSONB header".into());
562                }
563
564                // Skip the version byte
565                bytes = &bytes[1..]
566            };
567
568            let value = serde_json::from_slice(bytes)?;
569
570            Ok(PgJson(value))
571        }
572    }
573
574    impl sqlx::Type<Postgres> for PgJson {
575        fn type_info() -> PgTypeInfo {
576            PgTypeInfo::with_name("JSON")
577        }
578    }
579
580    pub struct PgJsonDeserializer {
581        value: Value,
582    }
583
584    impl<'de> Deserializer<'de> for PgJsonDeserializer {
585        type Error = DeError;
586
587        fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
588        where
589            V: de::Visitor<'de>,
590        {
591            // Delegate to serde_json::Value's own Deserializer
592            self.value.deserialize_any(visitor).map_err(DeError::custom)
593        }
594
595        forward_to_deserialize_any! {
596            bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string
597            bytes byte_buf option unit unit_struct newtype_struct seq tuple
598            tuple_struct map struct enum identifier ignored_any
599        }
600    }
601
602    impl<'de> IntoDeserializer<'de> for PgJson {
603        type Deserializer = PgJsonDeserializer;
604
605        fn into_deserializer(self) -> Self::Deserializer {
606            PgJsonDeserializer { value: self.0 }
607        }
608    }
609}