spin_sdk/
pg4.rs

1// pg4 errors can be large, because they now include a breakdown of the PostgreSQL
2// error fields instead of just a string
3#![allow(clippy::result_large_err)]
4
5//! Postgres relational database storage.
6//!
7//! You can use the [`into()`](std::convert::Into) method to convert
8//! a Rust value into a [`ParameterValue`]. You can use the
9//! [`Decode`] trait to convert a [`DbValue`] to a suitable Rust type.
10//! The following table shows available conversions.
11//!
12//! # Types
13//!
14//! | Rust type               | WIT (db-value)                                | Postgres type(s)             |
15//! |-------------------------|-----------------------------------------------|----------------------------- |
16//! | `bool`                  | boolean(bool)                                 | BOOL                         |
17//! | `i16`                   | int16(s16)                                    | SMALLINT, SMALLSERIAL, INT2  |
18//! | `i32`                   | int32(s32)                                    | INT, SERIAL, INT4            |
19//! | `i64`                   | int64(s64)                                    | BIGINT, BIGSERIAL, INT8      |
20//! | `f32`                   | floating32(float32)                           | REAL, FLOAT4                 |
21//! | `f64`                   | floating64(float64)                           | DOUBLE PRECISION, FLOAT8     |
22//! | `String`                | str(string)                                   | VARCHAR, CHAR(N), TEXT       |
23//! | `Vec<u8>`               | binary(list\<u8\>)                            | BYTEA                        |
24//! | `chrono::NaiveDate`     | date(tuple<s32, u8, u8>)                      | DATE                         |
25//! | `chrono::NaiveTime`     | time(tuple<u8, u8, u8, u32>)                  | TIME                         |
26//! | `chrono::NaiveDateTime` | datetime(tuple<s32, u8, u8, u8, u8, u8, u32>) | TIMESTAMP                    |
27//! | `chrono::Duration`      | timestamp(s64)                                | BIGINT                       |
28//! | `uuid::Uuid`            | uuid(string)                                  | UUID                         |
29//! | `serde_json::Value`     | jsonb(list\<u8\>)                             | JSONB                        |
30//! | `serde::De/Serialize`   | jsonb(list\<u8\>)                             | JSONB                        |
31//! | `rust_decimal::Decimal` | decimal(string)                               | NUMERIC                      |
32//! | `postgres_range`        | range-int32(...), range-int64(...)            | INT4RANGE, INT8RANGE         |
33//! | lower/upper tuple       | range-decimal(...)                            | NUMERICRANGE                 |
34//! | `Vec<Option<...>>`      | array-int32(...), array-int64(...), array-str(...), array-decimal(...) | INT4[], INT8[], TEXT[], NUMERIC[] |
35//! | `pg4::Interval`         | interval(interval)                            | INTERVAL                     |
36
37/// An open connection to a PostgreSQL database.
38///
39/// # Examples
40///
41/// Load a set of rows from a local PostgreSQL database, and iterate over them.
42///
43/// ```no_run
44/// use spin_sdk::pg4::{Connection, Decode};
45///
46/// # fn main() -> anyhow::Result<()> {
47/// # let min_age = 0;
48/// let db = Connection::open("host=localhost user=postgres password=my_password dbname=mydb")?;
49///
50/// let query_result = db.query(
51///     "SELECT * FROM users WHERE age >= $1",
52///     &[min_age.into()]
53/// )?;
54///
55/// let name_index = query_result.columns.iter().position(|c| c.name == "name").unwrap();
56///
57/// for row in &query_result.rows {
58///     let name = String::decode(&row[name_index])?;
59///     println!("Found user {name}");
60/// }
61/// # Ok(())
62/// # }
63/// ```
64///
65/// Perform an aggregate (scalar) operation over a table. The result set
66/// contains a single column, with a single row.
67///
68/// ```no_run
69/// use spin_sdk::pg4::{Connection, Decode};
70///
71/// # fn main() -> anyhow::Result<()> {
72/// let db = Connection::open("host=localhost user=postgres password=my_password dbname=mydb")?;
73///
74/// let query_result = db.query("SELECT COUNT(*) FROM users", &[])?;
75///
76/// assert_eq!(1, query_result.columns.len());
77/// assert_eq!("count", query_result.columns[0].name);
78/// assert_eq!(1, query_result.rows.len());
79///
80/// let count = i64::decode(&query_result.rows[0][0])?;
81/// # Ok(())
82/// # }
83/// ```
84///
85/// Delete rows from a PostgreSQL table. This uses [Connection::execute()]
86/// instead of the `query` method.
87///
88/// ```no_run
89/// use spin_sdk::pg4::Connection;
90///
91/// # fn main() -> anyhow::Result<()> {
92/// let db = Connection::open("host=localhost user=postgres password=my_password dbname=mydb")?;
93///
94/// let rows_affected = db.execute(
95///     "DELETE FROM users WHERE name = $1",
96///     &["Baldrick".to_owned().into()]
97/// )?;
98/// # Ok(())
99/// # }
100/// ```
101#[doc(inline)]
102pub use super::wit::pg4::Connection;
103
104/// The result of a database query.
105///
106/// # Examples
107///
108/// Load a set of rows from a local PostgreSQL database, and iterate over them
109/// selecting one field from each. The columns collection allows you to find
110/// column indexes for column names; you can bypass this lookup if you name
111/// specific columns in the query.
112///
113/// ```no_run
114/// use spin_sdk::pg4::{Connection, Decode};
115///
116/// # fn main() -> anyhow::Result<()> {
117/// # let min_age = 0;
118/// let db = Connection::open("host=localhost user=postgres password=my_password dbname=mydb")?;
119///
120/// let query_result = db.query(
121///     "SELECT * FROM users WHERE age >= $1",
122///     &[min_age.into()]
123/// )?;
124///
125/// let name_index = query_result.columns.iter().position(|c| c.name == "name").unwrap();
126///
127/// for row in &query_result.rows {
128///     let name = String::decode(&row[name_index])?;
129///     println!("Found user {name}");
130/// }
131/// # Ok(())
132/// # }
133/// ```
134pub use super::wit::pg4::RowSet;
135
136impl RowSet {
137    /// Get all the rows for this query result
138    pub fn rows(&self) -> impl Iterator<Item = Row<'_>> {
139        self.rows.iter().map(|r| Row {
140            columns: self.columns.as_slice(),
141            result: r,
142        })
143    }
144}
145
/// A database row result.
///
/// There are two representations of a PostgreSQL row in the SDK.  This type is useful for
/// addressing elements by column name, and is obtained from the [RowSet::rows()] function.
/// The [DbValue] vector representation is obtained from the [field@RowSet::rows] field, and provides
/// index-based lookup or low-level access to row values via a vector.
pub struct Row<'a> {
    // Column metadata of the originating result set; searched by name in `get`.
    columns: &'a [super::wit::pg4::Column],
    // The values of this row; `get` indexes it with the position of the matching column.
    result: &'a [DbValue],
}
156
157impl Row<'_> {
158    /// Get a value by its column name. The value is converted to the target type as per the
159    /// conversion table shown in the module documentation.
160    ///
161    /// This function returns None for both no such column _and_ failed conversion. You should use
162    /// it only if you do not need to address errors (that is, if you know that conversion should
163    /// never fail). If your code does not know the type in advance, use the raw [field@RowSet::rows] vector
164    /// instead of the `Row` wrapper to access the underlying [DbValue] enum: this will allow you to
165    /// determine the type and process it accordingly.
166    ///
167    /// Additionally, this function performs a name lookup each time it is called. If you are iterating
168    /// over a large number of rows, it's more efficient to use column indexes, either calculated or
169    /// statically known from the column order in the SQL.
170    ///
171    /// # Examples
172    ///
173    /// ```no_run
174    /// use spin_sdk::pg4::{Connection, DbValue};
175    ///
176    /// # fn main() -> anyhow::Result<()> {
177    /// # let user_id = 0;
178    /// let db = Connection::open("host=localhost user=postgres password=my_password dbname=mydb")?;
179    /// let query_result = db.query(
180    ///     "SELECT * FROM users WHERE id = $1",
181    ///     &[user_id.into()]
182    /// )?;
183    /// let user_row = query_result.rows().next().unwrap();
184    ///
185    /// let name = user_row.get::<String>("name").unwrap();
186    /// let age = user_row.get::<i16>("age").unwrap();
187    /// # Ok(())
188    /// # }
189    /// ```
190    pub fn get<T: Decode>(&self, column: &str) -> Option<T> {
191        let i = self.columns.iter().position(|c| c.name == column)?;
192        let db_value = self.result.get(i)?;
193        Decode::decode(db_value).ok()
194    }
195}
196
197#[doc(inline)]
198pub use super::wit::pg4::{Error as PgError, *};
199
200/// The PostgreSQL INTERVAL data type.
201pub use crate::pg4::Interval;
202
203use chrono::{Datelike, Timelike};
204
/// A Postgres error
///
/// Covers both failures reported by the database (`PgError`) and failures to
/// convert a returned [`DbValue`] into a Rust type (`Decode`).
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Failed to deserialize [`DbValue`]
    ///
    /// The payload is a human-readable description of the mismatch or parse failure.
    #[error("error value decoding: {0}")]
    Decode(String),
    /// Postgres query failed with an error
    ///
    /// Displayed transparently; `#[from]` lets `?` convert a [`PgError`] into this variant.
    #[error(transparent)]
    PgError(#[from] PgError),
}
215
/// A type that can be decoded from the database.
///
/// Implementations in this module accept the [`DbValue`] variant that corresponds
/// to the target type and report any other variant as [`Error::Decode`].
pub trait Decode: Sized {
    /// Decode a new value of this type using a [`DbValue`].
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when `value` cannot be converted to `Self`.
    fn decode(value: &DbValue) -> Result<Self, Error>;
}
221
222impl<T> Decode for Option<T>
223where
224    T: Decode,
225{
226    fn decode(value: &DbValue) -> Result<Self, Error> {
227        match value {
228            DbValue::DbNull => Ok(None),
229            v => Ok(Some(T::decode(v)?)),
230        }
231    }
232}
233
234impl Decode for bool {
235    fn decode(value: &DbValue) -> Result<Self, Error> {
236        match value {
237            DbValue::Boolean(boolean) => Ok(*boolean),
238            _ => Err(Error::Decode(format_decode_err("BOOL", value))),
239        }
240    }
241}
242
243impl Decode for i16 {
244    fn decode(value: &DbValue) -> Result<Self, Error> {
245        match value {
246            DbValue::Int16(n) => Ok(*n),
247            _ => Err(Error::Decode(format_decode_err("SMALLINT", value))),
248        }
249    }
250}
251
252impl Decode for i32 {
253    fn decode(value: &DbValue) -> Result<Self, Error> {
254        match value {
255            DbValue::Int32(n) => Ok(*n),
256            _ => Err(Error::Decode(format_decode_err("INT", value))),
257        }
258    }
259}
260
261impl Decode for i64 {
262    fn decode(value: &DbValue) -> Result<Self, Error> {
263        match value {
264            DbValue::Int64(n) => Ok(*n),
265            _ => Err(Error::Decode(format_decode_err("BIGINT", value))),
266        }
267    }
268}
269
270impl Decode for f32 {
271    fn decode(value: &DbValue) -> Result<Self, Error> {
272        match value {
273            DbValue::Floating32(n) => Ok(*n),
274            _ => Err(Error::Decode(format_decode_err("REAL", value))),
275        }
276    }
277}
278
279impl Decode for f64 {
280    fn decode(value: &DbValue) -> Result<Self, Error> {
281        match value {
282            DbValue::Floating64(n) => Ok(*n),
283            _ => Err(Error::Decode(format_decode_err("DOUBLE PRECISION", value))),
284        }
285    }
286}
287
288impl Decode for Vec<u8> {
289    fn decode(value: &DbValue) -> Result<Self, Error> {
290        match value {
291            DbValue::Binary(n) => Ok(n.to_owned()),
292            _ => Err(Error::Decode(format_decode_err("BYTEA", value))),
293        }
294    }
295}
296
297impl Decode for String {
298    fn decode(value: &DbValue) -> Result<Self, Error> {
299        match value {
300            DbValue::Str(s) => Ok(s.to_owned()),
301            _ => Err(Error::Decode(format_decode_err(
302                "CHAR, VARCHAR, TEXT",
303                value,
304            ))),
305        }
306    }
307}
308
309impl Decode for chrono::NaiveDate {
310    fn decode(value: &DbValue) -> Result<Self, Error> {
311        match value {
312            DbValue::Date((year, month, day)) => {
313                let naive_date =
314                    chrono::NaiveDate::from_ymd_opt(*year, (*month).into(), (*day).into())
315                        .ok_or_else(|| {
316                            Error::Decode(format!(
317                                "invalid date y={}, m={}, d={}",
318                                year, month, day
319                            ))
320                        })?;
321                Ok(naive_date)
322            }
323            _ => Err(Error::Decode(format_decode_err("DATE", value))),
324        }
325    }
326}
327
328impl Decode for chrono::NaiveTime {
329    fn decode(value: &DbValue) -> Result<Self, Error> {
330        match value {
331            DbValue::Time((hour, minute, second, nanosecond)) => {
332                let naive_time = chrono::NaiveTime::from_hms_nano_opt(
333                    (*hour).into(),
334                    (*minute).into(),
335                    (*second).into(),
336                    *nanosecond,
337                )
338                .ok_or_else(|| {
339                    Error::Decode(format!(
340                        "invalid time {}:{}:{}:{}",
341                        hour, minute, second, nanosecond
342                    ))
343                })?;
344                Ok(naive_time)
345            }
346            _ => Err(Error::Decode(format_decode_err("TIME", value))),
347        }
348    }
349}
350
351impl Decode for chrono::NaiveDateTime {
352    fn decode(value: &DbValue) -> Result<Self, Error> {
353        match value {
354            DbValue::Datetime((year, month, day, hour, minute, second, nanosecond)) => {
355                let naive_date =
356                    chrono::NaiveDate::from_ymd_opt(*year, (*month).into(), (*day).into())
357                        .ok_or_else(|| {
358                            Error::Decode(format!(
359                                "invalid date y={}, m={}, d={}",
360                                year, month, day
361                            ))
362                        })?;
363                let naive_time = chrono::NaiveTime::from_hms_nano_opt(
364                    (*hour).into(),
365                    (*minute).into(),
366                    (*second).into(),
367                    *nanosecond,
368                )
369                .ok_or_else(|| {
370                    Error::Decode(format!(
371                        "invalid time {}:{}:{}:{}",
372                        hour, minute, second, nanosecond
373                    ))
374                })?;
375                let dt = chrono::NaiveDateTime::new(naive_date, naive_time);
376                Ok(dt)
377            }
378            _ => Err(Error::Decode(format_decode_err("DATETIME", value))),
379        }
380    }
381}
382
383impl Decode for chrono::Duration {
384    fn decode(value: &DbValue) -> Result<Self, Error> {
385        match value {
386            DbValue::Timestamp(n) => Ok(chrono::Duration::seconds(*n)),
387            _ => Err(Error::Decode(format_decode_err("BIGINT", value))),
388        }
389    }
390}
391
#[cfg(feature = "postgres4-types")]
impl Decode for uuid::Uuid {
    /// Parses the string carried by a `DbValue::Uuid`.
    ///
    /// A parse failure or any other variant is reported as `Error::Decode`.
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::Uuid(text) = value else {
            return Err(Error::Decode(format_decode_err("UUID", value)));
        };
        uuid::Uuid::parse_str(text).map_err(|e| Error::Decode(e.to_string()))
    }
}
401
#[cfg(feature = "json")]
impl Decode for serde_json::Value {
    /// Decodes a JSONB value into an untyped JSON tree by delegating to [`from_jsonb`].
    fn decode(value: &DbValue) -> Result<Self, Error> {
        from_jsonb(value)
    }
}
408
/// Convert a Postgres JSONB value to a `Deserialize`-able type.
///
/// The bytes of a `DbValue::Jsonb` are deserialized with `serde_json`.
///
/// # Errors
///
/// Returns `Error::Decode` if `value` is not a `Jsonb` variant or if
/// deserialization into `T` fails.
#[cfg(feature = "json")]
pub fn from_jsonb<'a, T: serde::Deserialize<'a>>(value: &'a DbValue) -> Result<T, Error> {
    let DbValue::Jsonb(bytes) = value else {
        return Err(Error::Decode(format_decode_err("JSONB", value)));
    };
    serde_json::from_slice(bytes).map_err(|e| Error::Decode(e.to_string()))
}
417
#[cfg(feature = "postgres4-types")]
impl Decode for rust_decimal::Decimal {
    /// Parses the string carried by a `DbValue::Decimal` with
    /// `Decimal::from_str_exact`.
    ///
    /// A parse failure or any other variant is reported as `Error::Decode`.
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::Decimal(text) = value else {
            return Err(Error::Decode(format_decode_err("NUMERIC", value)));
        };
        rust_decimal::Decimal::from_str_exact(text).map_err(|e| Error::Decode(e.to_string()))
    }
}
429
/// Translates the WIT range bound kind into the equivalent `postgres_range` bound type.
#[cfg(feature = "postgres4-types")]
fn bound_type_from_wit(kind: RangeBoundKind) -> postgres_range::BoundType {
    use postgres_range::BoundType;

    match kind {
        RangeBoundKind::Exclusive => BoundType::Exclusive,
        RangeBoundKind::Inclusive => BoundType::Inclusive,
    }
}
437
#[cfg(feature = "postgres4-types")]
impl Decode for postgres_range::Range<i32> {
    /// Rebuilds a `postgres_range::Range` from the optional lower and upper
    /// bounds of a `DbValue::RangeInt32`; a missing bound stays unbounded.
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::RangeInt32((lbound, ubound)) = value else {
            return Err(Error::Decode(format_decode_err("INT4RANGE", value)));
        };
        let lower =
            lbound.map(|(n, kind)| postgres_range::RangeBound::new(n, bound_type_from_wit(kind)));
        let upper =
            ubound.map(|(n, kind)| postgres_range::RangeBound::new(n, bound_type_from_wit(kind)));
        Ok(postgres_range::Range::new(lower, upper))
    }
}
455
#[cfg(feature = "postgres4-types")]
impl Decode for postgres_range::Range<i64> {
    /// Rebuilds a `postgres_range::Range` from the optional lower and upper
    /// bounds of a `DbValue::RangeInt64`; a missing bound stays unbounded.
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::RangeInt64((lbound, ubound)) = value else {
            return Err(Error::Decode(format_decode_err("INT8RANGE", value)));
        };
        let lower =
            lbound.map(|(n, kind)| postgres_range::RangeBound::new(n, bound_type_from_wit(kind)));
        let upper =
            ubound.map(|(n, kind)| postgres_range::RangeBound::new(n, bound_type_from_wit(kind)));
        Ok(postgres_range::Range::new(lower, upper))
    }
}
473
474// We can't use postgres_range::Range because rust_decimal::Decimal
475// is not Normalizable
#[cfg(feature = "postgres4-types")]
impl Decode
    for (
        Option<(rust_decimal::Decimal, RangeBoundKind)>,
        Option<(rust_decimal::Decimal, RangeBoundKind)>,
    )
{
    /// Decodes a `DbValue::RangeDecimal` into a `(lower, upper)` pair of optional
    /// bounds, parsing each bound's string with `Decimal::from_str_exact`.
    ///
    /// The lower bound is parsed first; the first parse failure, or any other
    /// variant, is reported as `Error::Decode`.
    fn decode(value: &DbValue) -> Result<Self, Error> {
        type DecimalBound = (rust_decimal::Decimal, RangeBoundKind);

        // Parse one bound's textual value, keeping its inclusive/exclusive kind.
        fn parse_bound(text: &str, kind: RangeBoundKind) -> Result<DecimalBound, Error> {
            rust_decimal::Decimal::from_str_exact(text)
                .map(|dec| (dec, kind))
                .map_err(|e| Error::Decode(e.to_string()))
        }

        let DbValue::RangeDecimal((lbound, ubound)) = value else {
            return Err(Error::Decode(format_decode_err("NUMERICRANGE", value)));
        };

        let lower = match lbound {
            Some((text, kind)) => Some(parse_bound(text, *kind)?),
            None => None,
        };
        let upper = match ubound {
            Some((text, kind)) => Some(parse_bound(text, *kind)?),
            None => None,
        };
        Ok((lower, upper))
    }
}
509
510// TODO: can we return a slice here? It seems like it should be possible but
511// I wasn't able to get the lifetimes to work with the trait
512impl Decode for Vec<Option<i32>> {
513    fn decode(value: &DbValue) -> Result<Self, Error> {
514        match value {
515            DbValue::ArrayInt32(a) => Ok(a.to_vec()),
516            _ => Err(Error::Decode(format_decode_err("INT4[]", value))),
517        }
518    }
519}
520
521impl Decode for Vec<Option<i64>> {
522    fn decode(value: &DbValue) -> Result<Self, Error> {
523        match value {
524            DbValue::ArrayInt64(a) => Ok(a.to_vec()),
525            _ => Err(Error::Decode(format_decode_err("INT8[]", value))),
526        }
527    }
528}
529
530impl Decode for Vec<Option<String>> {
531    fn decode(value: &DbValue) -> Result<Self, Error> {
532        match value {
533            DbValue::ArrayStr(a) => Ok(a.to_vec()),
534            _ => Err(Error::Decode(format_decode_err("TEXT[]", value))),
535        }
536    }
537}
538
/// Parses an optional decimal string: `None` stays `None`, and `Some(text)` is
/// parsed with `Decimal::from_str_exact`, reporting failures as `Error::Decode`.
#[cfg(feature = "postgres4-types")]
fn map_decimal(s: &Option<String>) -> Result<Option<rust_decimal::Decimal>, Error> {
    match s {
        None => Ok(None),
        Some(text) => rust_decimal::Decimal::from_str_exact(text)
            .map(Some)
            .map_err(|e| Error::Decode(e.to_string())),
    }
}
546
#[cfg(feature = "postgres4-types")]
impl Decode for Vec<Option<rust_decimal::Decimal>> {
    /// Parses every element of a `DbValue::ArrayDecimal`, stopping at the first
    /// element that fails to parse; any other variant is a decode error.
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::ArrayDecimal(items) = value else {
            return Err(Error::Decode(format_decode_err("NUMERIC[]", value)));
        };
        items.iter().map(map_decimal).collect()
    }
}
559
560impl Decode for Interval {
561    fn decode(value: &DbValue) -> Result<Self, Error> {
562        match value {
563            DbValue::Interval(i) => Ok(*i),
564            _ => Err(Error::Decode(format_decode_err("INTERVAL", value))),
565        }
566    }
567}
568
// Generates `impl From<$ty> for ParameterValue` for each `$ty => $id` pair, where
// `$id` is the `ParameterValue` variant that wraps a value of type `$ty` unchanged.
// Pairs are comma-separated; a trailing comma is not accepted by this pattern.
macro_rules! impl_parameter_value_conversions {
    ($($ty:ty => $id:ident),*) => {
        $(
            impl From<$ty> for ParameterValue {
                fn from(v: $ty) -> ParameterValue {
                    ParameterValue::$id(v)
                }
            }
        )*
    };
}
580
// Direct conversions: each Rust type on the left is stored as-is in the
// `ParameterValue` variant on the right.
impl_parameter_value_conversions! {
    i8 => Int8,
    i16 => Int16,
    i32 => Int32,
    i64 => Int64,
    f32 => Floating32,
    f64 => Floating64,
    bool => Boolean,
    String => Str,
    Vec<u8> => Binary,
    Vec<Option<i32>> => ArrayInt32,
    Vec<Option<i64>> => ArrayInt64,
    Vec<Option<String>> => ArrayStr
}
595
596impl From<chrono::NaiveDateTime> for ParameterValue {
597    fn from(v: chrono::NaiveDateTime) -> ParameterValue {
598        ParameterValue::Datetime((
599            v.year(),
600            v.month() as u8,
601            v.day() as u8,
602            v.hour() as u8,
603            v.minute() as u8,
604            v.second() as u8,
605            v.nanosecond(),
606        ))
607    }
608}
609
610impl From<chrono::NaiveTime> for ParameterValue {
611    fn from(v: chrono::NaiveTime) -> ParameterValue {
612        ParameterValue::Time((
613            v.hour() as u8,
614            v.minute() as u8,
615            v.second() as u8,
616            v.nanosecond(),
617        ))
618    }
619}
620
621impl From<chrono::NaiveDate> for ParameterValue {
622    fn from(v: chrono::NaiveDate) -> ParameterValue {
623        ParameterValue::Date((v.year(), v.month() as u8, v.day() as u8))
624    }
625}
626
627impl From<chrono::TimeDelta> for ParameterValue {
628    fn from(v: chrono::TimeDelta) -> ParameterValue {
629        ParameterValue::Timestamp(v.num_seconds())
630    }
631}
632
#[cfg(feature = "postgres4-types")]
impl From<uuid::Uuid> for ParameterValue {
    /// Stores the UUID's string form in `ParameterValue::Uuid`.
    fn from(v: uuid::Uuid) -> ParameterValue {
        let text = v.to_string();
        ParameterValue::Uuid(text)
    }
}
639
#[cfg(feature = "json")]
impl TryFrom<serde_json::Value> for ParameterValue {
    // Serialization is the only step that can fail, so its error type is surfaced directly.
    type Error = serde_json::Error;

    /// Converts a JSON tree into a JSONB parameter by delegating to [`jsonb`].
    fn try_from(v: serde_json::Value) -> Result<ParameterValue, Self::Error> {
        jsonb(&v)
    }
}
648
/// Converts a `Serialize` value to a Postgres JSONB SQL parameter.
///
/// The value is serialized to JSON bytes with `serde_json::to_vec` and wrapped
/// in `ParameterValue::Jsonb`.
///
/// # Errors
///
/// Returns the `serde_json::Error` produced if serialization fails.
#[cfg(feature = "json")]
pub fn jsonb<T: serde::Serialize>(value: &T) -> Result<ParameterValue, serde_json::Error> {
    serde_json::to_vec(value).map(ParameterValue::Jsonb)
}
655
#[cfg(feature = "postgres4-types")]
impl From<rust_decimal::Decimal> for ParameterValue {
    /// Stores the decimal's string form in `ParameterValue::Decimal`.
    fn from(v: rust_decimal::Decimal) -> ParameterValue {
        let text = v.to_string();
        ParameterValue::Decimal(text)
    }
}
662
663// We cannot impl From<T: RangeBounds<...>> because Rust fears that some future
664// knave or rogue might one day add RangeBounds to NaiveDateTime. The best we can
665// do is therefore a helper function we can call from range Froms.
666#[allow(
667    clippy::type_complexity,
668    reason = "I sure hope 'blame Alex' works here too"
669)]
670fn range_bounds_to_wit<T, U>(
671    range: impl std::ops::RangeBounds<T>,
672    f: impl Fn(&T) -> U,
673) -> (Option<(U, RangeBoundKind)>, Option<(U, RangeBoundKind)>) {
674    (
675        range_bound_to_wit(range.start_bound(), &f),
676        range_bound_to_wit(range.end_bound(), &f),
677    )
678}
679
680fn range_bound_to_wit<T, U>(
681    bound: std::ops::Bound<&T>,
682    f: &dyn Fn(&T) -> U,
683) -> Option<(U, RangeBoundKind)> {
684    match bound {
685        std::ops::Bound::Included(v) => Some((f(v), RangeBoundKind::Inclusive)),
686        std::ops::Bound::Excluded(v) => Some((f(v), RangeBoundKind::Exclusive)),
687        std::ops::Bound::Unbounded => None,
688    }
689}
690
/// Converts a `postgres_range` bound into the WIT `(value, kind)` pair, copying
/// the bound's value and translating its inclusive/exclusive type.
#[cfg(feature = "postgres4-types")]
fn pg_range_bound_to_wit<S: postgres_range::BoundSided, T: Copy>(
    bound: &postgres_range::RangeBound<S, T>,
) -> (T, RangeBoundKind) {
    use postgres_range::BoundType;

    match bound.type_ {
        BoundType::Inclusive => (bound.value, RangeBoundKind::Inclusive),
        BoundType::Exclusive => (bound.value, RangeBoundKind::Exclusive),
    }
}
701
702impl From<std::ops::Range<i32>> for ParameterValue {
703    fn from(v: std::ops::Range<i32>) -> ParameterValue {
704        ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
705    }
706}
707
708impl From<std::ops::RangeInclusive<i32>> for ParameterValue {
709    fn from(v: std::ops::RangeInclusive<i32>) -> ParameterValue {
710        ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
711    }
712}
713
714impl From<std::ops::RangeFrom<i32>> for ParameterValue {
715    fn from(v: std::ops::RangeFrom<i32>) -> ParameterValue {
716        ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
717    }
718}
719
720impl From<std::ops::RangeTo<i32>> for ParameterValue {
721    fn from(v: std::ops::RangeTo<i32>) -> ParameterValue {
722        ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
723    }
724}
725
726impl From<std::ops::RangeToInclusive<i32>> for ParameterValue {
727    fn from(v: std::ops::RangeToInclusive<i32>) -> ParameterValue {
728        ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
729    }
730}
731
#[cfg(feature = "postgres4-types")]
impl From<postgres_range::Range<i32>> for ParameterValue {
    /// Converts a `postgres_range::Range` into a `ParameterValue::RangeInt32`,
    /// translating whichever of its lower and upper bounds are present.
    fn from(v: postgres_range::Range<i32>) -> ParameterValue {
        let bounds = (
            v.lower().map(pg_range_bound_to_wit),
            v.upper().map(pg_range_bound_to_wit),
        );
        ParameterValue::RangeInt32(bounds)
    }
}
740
741impl From<std::ops::Range<i64>> for ParameterValue {
742    fn from(v: std::ops::Range<i64>) -> ParameterValue {
743        ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
744    }
745}
746
747impl From<std::ops::RangeInclusive<i64>> for ParameterValue {
748    fn from(v: std::ops::RangeInclusive<i64>) -> ParameterValue {
749        ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
750    }
751}
752
753impl From<std::ops::RangeFrom<i64>> for ParameterValue {
754    fn from(v: std::ops::RangeFrom<i64>) -> ParameterValue {
755        ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
756    }
757}
758
759impl From<std::ops::RangeTo<i64>> for ParameterValue {
760    fn from(v: std::ops::RangeTo<i64>) -> ParameterValue {
761        ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
762    }
763}
764
765impl From<std::ops::RangeToInclusive<i64>> for ParameterValue {
766    fn from(v: std::ops::RangeToInclusive<i64>) -> ParameterValue {
767        ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
768    }
769}
770
#[cfg(feature = "postgres4-types")]
impl From<postgres_range::Range<i64>> for ParameterValue {
    /// Converts a `postgres_range::Range` into a `ParameterValue::RangeInt64`,
    /// translating whichever of its lower and upper bounds are present.
    fn from(v: postgres_range::Range<i64>) -> ParameterValue {
        let bounds = (
            v.lower().map(pg_range_bound_to_wit),
            v.upper().map(pg_range_bound_to_wit),
        );
        ParameterValue::RangeInt64(bounds)
    }
}
779
#[cfg(feature = "postgres4-types")]
impl From<std::ops::Range<rust_decimal::Decimal>> for ParameterValue {
    /// Converts a `start..end` decimal range into a `ParameterValue::RangeDecimal`,
    /// rendering each bound as its string form.
    fn from(v: std::ops::Range<rust_decimal::Decimal>) -> ParameterValue {
        let bounds = range_bounds_to_wit(v, |d| d.to_string());
        ParameterValue::RangeDecimal(bounds)
    }
}
786
787impl From<Vec<i32>> for ParameterValue {
788    fn from(v: Vec<i32>) -> ParameterValue {
789        ParameterValue::ArrayInt32(v.into_iter().map(Some).collect())
790    }
791}
792
793impl From<Vec<i64>> for ParameterValue {
794    fn from(v: Vec<i64>) -> ParameterValue {
795        ParameterValue::ArrayInt64(v.into_iter().map(Some).collect())
796    }
797}
798
799impl From<Vec<String>> for ParameterValue {
800    fn from(v: Vec<String>) -> ParameterValue {
801        ParameterValue::ArrayStr(v.into_iter().map(Some).collect())
802    }
803}
804
#[cfg(feature = "postgres4-types")]
impl From<Vec<Option<rust_decimal::Decimal>>> for ParameterValue {
    /// Builds a `ParameterValue::ArrayDecimal`, rendering each present decimal as
    /// its string form and keeping `None` entries as `None`.
    fn from(v: Vec<Option<rust_decimal::Decimal>>) -> ParameterValue {
        let mut strs = Vec::with_capacity(v.len());
        for entry in v {
            strs.push(entry.map(|d| d.to_string()));
        }
        ParameterValue::ArrayDecimal(strs)
    }
}
815
#[cfg(feature = "postgres4-types")]
impl From<Vec<rust_decimal::Decimal>> for ParameterValue {
    /// Builds a `ParameterValue::ArrayDecimal` in which every element is present,
    /// rendering each decimal as its string form.
    fn from(v: Vec<rust_decimal::Decimal>) -> ParameterValue {
        let mut strs = Vec::with_capacity(v.len());
        for d in v {
            strs.push(Some(d.to_string()));
        }
        ParameterValue::ArrayDecimal(strs)
    }
}
823
impl From<Interval> for ParameterValue {
    /// Wraps the interval unchanged in `ParameterValue::Interval`.
    fn from(v: Interval) -> ParameterValue {
        ParameterValue::Interval(v)
    }
}
829
830impl<T: Into<ParameterValue>> From<Option<T>> for ParameterValue {
831    fn from(o: Option<T>) -> ParameterValue {
832        match o {
833            Some(v) => v.into(),
834            None => ParameterValue::DbNull,
835        }
836    }
837}
838
/// Builds the message used when a [`DbValue`] is not the variant a decoder expects.
///
/// `types` names the expected Postgres type(s); `value` is rendered with its
/// `Debug` representation.
fn format_decode_err(types: &str, value: &DbValue) -> String {
    format!("Expected {} from the DB but got {:?}", types, value)
}
842
843#[cfg(test)]
844mod tests {
845    use chrono::NaiveDateTime;
846
847    use super::*;
848
849    #[test]
850    fn boolean() {
851        assert!(bool::decode(&DbValue::Boolean(true)).unwrap());
852        assert!(bool::decode(&DbValue::Int32(0)).is_err());
853        assert!(Option::<bool>::decode(&DbValue::DbNull).unwrap().is_none());
854    }
855
856    #[test]
857    fn int16() {
858        assert_eq!(i16::decode(&DbValue::Int16(0)).unwrap(), 0);
859        assert!(i16::decode(&DbValue::Int32(0)).is_err());
860        assert!(Option::<i16>::decode(&DbValue::DbNull).unwrap().is_none());
861    }
862
863    #[test]
864    fn int32() {
865        assert_eq!(i32::decode(&DbValue::Int32(0)).unwrap(), 0);
866        assert!(i32::decode(&DbValue::Boolean(false)).is_err());
867        assert!(Option::<i32>::decode(&DbValue::DbNull).unwrap().is_none());
868    }
869
870    #[test]
871    fn int64() {
872        assert_eq!(i64::decode(&DbValue::Int64(0)).unwrap(), 0);
873        assert!(i64::decode(&DbValue::Boolean(false)).is_err());
874        assert!(Option::<i64>::decode(&DbValue::DbNull).unwrap().is_none());
875    }
876
877    #[test]
878    fn floating32() {
879        assert!(f32::decode(&DbValue::Floating32(0.0)).is_ok());
880        assert!(f32::decode(&DbValue::Boolean(false)).is_err());
881        assert!(Option::<f32>::decode(&DbValue::DbNull).unwrap().is_none());
882    }
883
884    #[test]
885    fn floating64() {
886        assert!(f64::decode(&DbValue::Floating64(0.0)).is_ok());
887        assert!(f64::decode(&DbValue::Boolean(false)).is_err());
888        assert!(Option::<f64>::decode(&DbValue::DbNull).unwrap().is_none());
889    }
890
891    #[test]
892    fn str() {
893        assert_eq!(
894            String::decode(&DbValue::Str(String::from("foo"))).unwrap(),
895            String::from("foo")
896        );
897
898        assert!(String::decode(&DbValue::Int32(0)).is_err());
899        assert!(Option::<String>::decode(&DbValue::DbNull)
900            .unwrap()
901            .is_none());
902    }
903
904    #[test]
905    fn binary() {
906        assert!(Vec::<u8>::decode(&DbValue::Binary(vec![0, 0])).is_ok());
907        assert!(Vec::<u8>::decode(&DbValue::Boolean(false)).is_err());
908        assert!(Option::<Vec<u8>>::decode(&DbValue::DbNull)
909            .unwrap()
910            .is_none());
911    }
912
913    #[test]
914    fn date() {
915        assert_eq!(
916            chrono::NaiveDate::decode(&DbValue::Date((1, 2, 4))).unwrap(),
917            chrono::NaiveDate::from_ymd_opt(1, 2, 4).unwrap()
918        );
919        assert_ne!(
920            chrono::NaiveDate::decode(&DbValue::Date((1, 2, 4))).unwrap(),
921            chrono::NaiveDate::from_ymd_opt(1, 2, 5).unwrap()
922        );
923        assert!(Option::<chrono::NaiveDate>::decode(&DbValue::DbNull)
924            .unwrap()
925            .is_none());
926    }
927
928    #[test]
929    fn time() {
930        assert_eq!(
931            chrono::NaiveTime::decode(&DbValue::Time((1, 2, 3, 4))).unwrap(),
932            chrono::NaiveTime::from_hms_nano_opt(1, 2, 3, 4).unwrap()
933        );
934        assert_ne!(
935            chrono::NaiveTime::decode(&DbValue::Time((1, 2, 3, 4))).unwrap(),
936            chrono::NaiveTime::from_hms_nano_opt(1, 2, 4, 5).unwrap()
937        );
938        assert!(Option::<chrono::NaiveTime>::decode(&DbValue::DbNull)
939            .unwrap()
940            .is_none());
941    }
942
943    #[test]
944    fn datetime() {
945        let date = chrono::NaiveDate::from_ymd_opt(1, 2, 3).unwrap();
946        let mut time = chrono::NaiveTime::from_hms_nano_opt(4, 5, 6, 7).unwrap();
947        assert_eq!(
948            chrono::NaiveDateTime::decode(&DbValue::Datetime((1, 2, 3, 4, 5, 6, 7))).unwrap(),
949            chrono::NaiveDateTime::new(date, time)
950        );
951
952        time = chrono::NaiveTime::from_hms_nano_opt(4, 5, 6, 8).unwrap();
953        assert_ne!(
954            NaiveDateTime::decode(&DbValue::Datetime((1, 2, 3, 4, 5, 6, 7))).unwrap(),
955            chrono::NaiveDateTime::new(date, time)
956        );
957        assert!(Option::<chrono::NaiveDateTime>::decode(&DbValue::DbNull)
958            .unwrap()
959            .is_none());
960    }
961
962    #[test]
963    fn timestamp() {
964        assert_eq!(
965            chrono::Duration::decode(&DbValue::Timestamp(1)).unwrap(),
966            chrono::Duration::seconds(1),
967        );
968        assert_ne!(
969            chrono::Duration::decode(&DbValue::Timestamp(2)).unwrap(),
970            chrono::Duration::seconds(1)
971        );
972        assert!(Option::<chrono::Duration>::decode(&DbValue::DbNull)
973            .unwrap()
974            .is_none());
975    }
976
977    #[test]
978    #[cfg(feature = "postgres4-types")]
979    fn uuid() {
980        let uuid_str = "12341234-1234-1234-1234-123412341234";
981        assert_eq!(
982            uuid::Uuid::try_parse(uuid_str).unwrap(),
983            uuid::Uuid::decode(&DbValue::Uuid(uuid_str.to_owned())).unwrap(),
984        );
985        assert!(Option::<uuid::Uuid>::decode(&DbValue::DbNull)
986            .unwrap()
987            .is_none());
988    }
989
990    #[derive(Debug, serde::Deserialize, PartialEq)]
991    struct JsonTest {
992        hello: String,
993    }
994
995    #[test]
996    #[cfg(feature = "json")]
997    fn jsonb() {
998        let json_val = serde_json::json!({
999            "hello": "world"
1000        });
1001        let dbval = DbValue::Jsonb(r#"{"hello":"world"}"#.into());
1002
1003        assert_eq!(json_val, serde_json::Value::decode(&dbval).unwrap(),);
1004
1005        let json_struct = JsonTest {
1006            hello: "world".to_owned(),
1007        };
1008        assert_eq!(json_struct, from_jsonb(&dbval).unwrap());
1009    }
1010
1011    #[test]
1012    #[cfg(feature = "postgres4-types")]
1013    fn ranges() {
1014        let i32_range = postgres_range::Range::<i32>::decode(&DbValue::RangeInt32((
1015            Some((45, RangeBoundKind::Inclusive)),
1016            Some((89, RangeBoundKind::Exclusive)),
1017        )))
1018        .unwrap();
1019        assert_eq!(45, i32_range.lower().unwrap().value);
1020        assert_eq!(
1021            postgres_range::BoundType::Inclusive,
1022            i32_range.lower().unwrap().type_
1023        );
1024        assert_eq!(89, i32_range.upper().unwrap().value);
1025        assert_eq!(
1026            postgres_range::BoundType::Exclusive,
1027            i32_range.upper().unwrap().type_
1028        );
1029
1030        let i32_range_from = postgres_range::Range::<i32>::decode(&DbValue::RangeInt32((
1031            Some((45, RangeBoundKind::Inclusive)),
1032            None,
1033        )))
1034        .unwrap();
1035        assert!(i32_range_from.upper().is_none());
1036
1037        let i64_range = postgres_range::Range::<i64>::decode(&DbValue::RangeInt64((
1038            Some((4567456745674567, RangeBoundKind::Inclusive)),
1039            Some((890189018901890189, RangeBoundKind::Exclusive)),
1040        )))
1041        .unwrap();
1042        assert_eq!(4567456745674567, i64_range.lower().unwrap().value);
1043        assert_eq!(890189018901890189, i64_range.upper().unwrap().value);
1044
1045        #[allow(clippy::type_complexity)]
1046        let (dec_lbound, dec_ubound): (
1047            Option<(rust_decimal::Decimal, RangeBoundKind)>,
1048            Option<(rust_decimal::Decimal, RangeBoundKind)>,
1049        ) = Decode::decode(&DbValue::RangeDecimal((
1050            Some(("4567.8901".to_owned(), RangeBoundKind::Inclusive)),
1051            Some(("8901.2345678901".to_owned(), RangeBoundKind::Exclusive)),
1052        )))
1053        .unwrap();
1054        assert_eq!(
1055            rust_decimal::Decimal::from_i128_with_scale(45678901, 4),
1056            dec_lbound.unwrap().0
1057        );
1058        assert_eq!(
1059            rust_decimal::Decimal::from_i128_with_scale(89012345678901, 10),
1060            dec_ubound.unwrap().0
1061        );
1062    }
1063
1064    #[test]
1065    #[cfg(feature = "postgres4-types")]
1066    fn arrays() {
1067        let v32 = vec![Some(123), None, Some(456)];
1068        let i32_arr = Vec::<Option<i32>>::decode(&DbValue::ArrayInt32(v32.clone())).unwrap();
1069        assert_eq!(v32, i32_arr);
1070
1071        let v64 = vec![Some(123), None, Some(456)];
1072        let i64_arr = Vec::<Option<i64>>::decode(&DbValue::ArrayInt64(v64.clone())).unwrap();
1073        assert_eq!(v64, i64_arr);
1074
1075        let vdec = vec![Some("1.23".to_owned()), None];
1076        let dec_arr =
1077            Vec::<Option<rust_decimal::Decimal>>::decode(&DbValue::ArrayDecimal(vdec)).unwrap();
1078        assert_eq!(
1079            vec![
1080                Some(rust_decimal::Decimal::from_i128_with_scale(123, 2)),
1081                None
1082            ],
1083            dec_arr
1084        );
1085
1086        let vstr = vec![Some("alice".to_owned()), None, Some("bob".to_owned())];
1087        let str_arr = Vec::<Option<String>>::decode(&DbValue::ArrayStr(vstr.clone())).unwrap();
1088        assert_eq!(vstr, str_arr);
1089    }
1090}