1#![allow(clippy::result_large_err)]
36
37use crate::wit_bindgen;
38use std::sync::Arc;
39
/// Raw bindings generated by `wit_bindgen` for the `spin-sdk-pg` world found in the
/// `wit` directory.
///
/// Hidden from the rendered docs; the types users need are re-exported from the
/// parent module (`Column`, `DbValue`, `ParameterValue`, ...).
#[doc(hidden)]
pub mod wit {
    #![allow(missing_docs)]
    use crate::wit_bindgen;

    // Generate bindings for every interface reachable from the `spin-sdk-pg` world,
    // routing generated runtime calls through this crate's re-exported `wit_bindgen`.
    wit_bindgen::generate!({
        runtime_path: "crate::wit_bindgen::rt",
        world: "spin-sdk-pg",
        path: "wit",
        generate_all,
    });

    // Shorten the path to the Postgres interface to `wit::postgres::...`.
    pub use spin::postgres::postgres;
}
57
58#[doc(inline)]
59pub use wit::postgres::{
60 Column, DbDataType, DbError, DbValue, Error as PgError, ParameterValue, QueryError,
61 RangeBoundKind,
62};
63
64pub use wit::postgres::Interval;
66
67use chrono::{Datelike, Timelike};
68
/// A connection to a PostgreSQL database.
///
/// Open one with [`Connection::open`] or [`Connection::open_with_options`], then run
/// statements with [`Connection::query`] and [`Connection::execute`].
pub struct Connection(wit::postgres::Connection);
137
/// Options controlling how [`Connection::open_with_options`] builds a connection.
#[derive(Debug, Clone, Default)]
pub struct OpenOptions {
    /// Certificate to install as the connection's CA root.
    ///
    /// When `None`, no CA root is set explicitly on the connection builder.
    pub ca_root: Option<Certificate>,
}

/// A certificate, supplied either as a path to a file or as inline text.
#[derive(Debug, Clone)]
pub enum Certificate {
    /// Path to a file whose entire contents are the certificate text.
    FilePath(String),
    /// The certificate text itself.
    Text(String),
}
153
154impl Certificate {
155 fn load(self) -> Result<String, Error> {
156 match self {
157 Certificate::FilePath(path) => std::fs::read_to_string(path)
158 .map_err(|e| Error::PgError(PgError::Other(e.to_string()))),
159 Certificate::Text(text) => Ok(text),
160 }
161 }
162}
163
164impl Connection {
165 pub async fn open(address: impl Into<String>) -> Result<Self, Error> {
173 let inner = wit::postgres::Connection::open_async(address.into()).await?;
174 Ok(Self(inner))
175 }
176
177 pub async fn open_with_options(
184 address: impl AsRef<str>,
185 options: OpenOptions,
186 ) -> Result<Self, Error> {
187 let builder = wit::postgres::ConnectionBuilder::new(address.as_ref());
188 let OpenOptions { ca_root } = options;
189
190 if let Some(ca_root) = ca_root {
191 let ca_root_text = ca_root.load()?;
192 builder.set_ca_root(&ca_root_text)?;
193 }
194
195 let inner = builder.build_async().await?;
196 Ok(Self(inner))
197 }
198
199 pub async fn query(
204 &self,
205 statement: impl Into<String>,
206 params: impl Into<Vec<ParameterValue>>,
207 ) -> Result<QueryResult, Error> {
208 let (columns, rows, result) = self.0.query_async(statement.into(), params.into()).await?;
209 Ok(QueryResult {
211 columns: Arc::new(columns),
212 rows,
213 result,
214 })
215 }
216
217 pub async fn execute(
222 &self,
223 statement: impl Into<String>,
224 params: impl Into<Vec<ParameterValue>>,
225 ) -> Result<u64, Error> {
226 self.0
227 .execute_async(statement.into(), params.into())
228 .await
229 .map_err(Error::PgError)
230 }
231
232 pub fn into_inner(self) -> wit::postgres::Connection {
234 self.0
235 }
236}
237
/// The result of [`Connection::query`]: column metadata, a stream of rows, and a
/// future carrying the query's final outcome.
pub struct QueryResult {
    // Shared (via `Arc`) with every `Row` produced by `next`, so rows can look up
    // columns by name without copying the metadata.
    columns: Arc<Vec<Column>>,
    // Raw row values, one `Vec<DbValue>` per row.
    rows: wit_bindgen::StreamReader<Vec<DbValue>>,
    // Final outcome of the query; awaited by `result` and `collect`.
    result: wit_bindgen::FutureReader<Result<(), PgError>>,
}
244
245impl QueryResult {
246 pub fn columns(&self) -> &[Column] {
248 &self.columns
249 }
250
251 pub async fn next(&mut self) -> Option<Row> {
261 self.rows.next().await.map(|r| Row {
262 columns: self.columns.clone(),
263 result: r,
264 })
265 }
266
267 pub async fn result(self) -> Result<(), Error> {
269 self.result.await.map_err(Error::PgError)
270 }
271
272 pub async fn collect(mut self) -> Result<Vec<Row>, Error> {
277 let mut rows = vec![];
278 while let Some(row) = self.next().await {
279 rows.push(row);
280 }
281 self.result.await.map_err(Error::PgError)?;
282 Ok(rows)
283 }
284
285 pub fn rows(&mut self) -> &mut wit_bindgen::StreamReader<Vec<DbValue>> {
296 &mut self.rows
297 }
298
299 #[allow(
301 clippy::type_complexity,
302 reason = "sorry clippy that's just what the inner bits are"
303 )]
304 pub fn into_inner(
305 self,
306 ) -> (
307 Vec<Column>,
308 wit_bindgen::StreamReader<Vec<DbValue>>,
309 wit_bindgen::FutureReader<Result<(), PgError>>,
310 ) {
311 ((*self.columns).clone(), self.rows, self.result)
312 }
313}
314
/// A single row of a [`QueryResult`], holding the row's values together with the
/// result's column metadata.
///
/// Use [`Row::get`] to decode a value by column name, or index with `row[i]` to get
/// the raw [`DbValue`] at position `i`.
pub struct Row {
    // Column metadata shared with the originating `QueryResult`.
    columns: Arc<Vec<wit::postgres::Column>>,
    // Values in column order; `get` indexes this by a column's position in `columns`.
    result: Vec<DbValue>,
}
325
326impl Row {
327 pub fn get<T: Decode>(&self, column: &str) -> Option<T> {
361 let i = self.columns.iter().position(|c| c.name == column)?;
362 let db_value = self.result.get(i)?;
363 Decode::decode(db_value).ok()
364 }
365}
366
367impl std::ops::Index<usize> for Row {
368 type Output = DbValue;
369
370 fn index(&self, index: usize) -> &Self::Output {
371 &self.result[index]
372 }
373}
374
/// Errors returned by this module.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// A database value could not be decoded into the requested Rust type.
    #[error("error value decoding: {0}")]
    Decode(String),
    /// An error from the underlying Postgres bindings. This variant is also used to
    /// report failures to read a certificate file.
    #[error(transparent)]
    PgError(#[from] PgError),
}
385
/// Conversion from a raw [`DbValue`] into a Rust type.
pub trait Decode: Sized {
    /// Decodes `value` into `Self`.
    ///
    /// The implementations in this module return [`Error::Decode`] when `value` is
    /// not the expected variant or its contents cannot be converted.
    fn decode(value: &DbValue) -> Result<Self, Error>;
}
391
392impl<T> Decode for Option<T>
393where
394 T: Decode,
395{
396 fn decode(value: &DbValue) -> Result<Self, Error> {
397 match value {
398 DbValue::DbNull => Ok(None),
399 v => Ok(Some(T::decode(v)?)),
400 }
401 }
402}
403
404impl Decode for bool {
405 fn decode(value: &DbValue) -> Result<Self, Error> {
406 match value {
407 DbValue::Boolean(boolean) => Ok(*boolean),
408 _ => Err(Error::Decode(format_decode_err("BOOL", value))),
409 }
410 }
411}
412
413impl Decode for i16 {
414 fn decode(value: &DbValue) -> Result<Self, Error> {
415 match value {
416 DbValue::Int16(n) => Ok(*n),
417 _ => Err(Error::Decode(format_decode_err("SMALLINT", value))),
418 }
419 }
420}
421
422impl Decode for i32 {
423 fn decode(value: &DbValue) -> Result<Self, Error> {
424 match value {
425 DbValue::Int32(n) => Ok(*n),
426 _ => Err(Error::Decode(format_decode_err("INT", value))),
427 }
428 }
429}
430
431impl Decode for i64 {
432 fn decode(value: &DbValue) -> Result<Self, Error> {
433 match value {
434 DbValue::Int64(n) => Ok(*n),
435 _ => Err(Error::Decode(format_decode_err("BIGINT", value))),
436 }
437 }
438}
439
440impl Decode for f32 {
441 fn decode(value: &DbValue) -> Result<Self, Error> {
442 match value {
443 DbValue::Floating32(n) => Ok(*n),
444 _ => Err(Error::Decode(format_decode_err("REAL", value))),
445 }
446 }
447}
448
449impl Decode for f64 {
450 fn decode(value: &DbValue) -> Result<Self, Error> {
451 match value {
452 DbValue::Floating64(n) => Ok(*n),
453 _ => Err(Error::Decode(format_decode_err("DOUBLE PRECISION", value))),
454 }
455 }
456}
457
458impl Decode for Vec<u8> {
459 fn decode(value: &DbValue) -> Result<Self, Error> {
460 match value {
461 DbValue::Binary(n) => Ok(n.to_owned()),
462 _ => Err(Error::Decode(format_decode_err("BYTEA", value))),
463 }
464 }
465}
466
467impl Decode for String {
468 fn decode(value: &DbValue) -> Result<Self, Error> {
469 match value {
470 DbValue::Str(s) => Ok(s.to_owned()),
471 _ => Err(Error::Decode(format_decode_err(
472 "CHAR, VARCHAR, TEXT",
473 value,
474 ))),
475 }
476 }
477}
478
479impl Decode for chrono::NaiveDate {
480 fn decode(value: &DbValue) -> Result<Self, Error> {
481 match value {
482 DbValue::Date((year, month, day)) => {
483 let naive_date =
484 chrono::NaiveDate::from_ymd_opt(*year, (*month).into(), (*day).into())
485 .ok_or_else(|| {
486 Error::Decode(format!(
487 "invalid date y={}, m={}, d={}",
488 year, month, day
489 ))
490 })?;
491 Ok(naive_date)
492 }
493 _ => Err(Error::Decode(format_decode_err("DATE", value))),
494 }
495 }
496}
497
498impl Decode for chrono::NaiveTime {
499 fn decode(value: &DbValue) -> Result<Self, Error> {
500 match value {
501 DbValue::Time((hour, minute, second, nanosecond)) => {
502 let naive_time = chrono::NaiveTime::from_hms_nano_opt(
503 (*hour).into(),
504 (*minute).into(),
505 (*second).into(),
506 *nanosecond,
507 )
508 .ok_or_else(|| {
509 Error::Decode(format!(
510 "invalid time {}:{}:{}:{}",
511 hour, minute, second, nanosecond
512 ))
513 })?;
514 Ok(naive_time)
515 }
516 _ => Err(Error::Decode(format_decode_err("TIME", value))),
517 }
518 }
519}
520
521impl Decode for chrono::NaiveDateTime {
522 fn decode(value: &DbValue) -> Result<Self, Error> {
523 match value {
524 DbValue::Datetime((year, month, day, hour, minute, second, nanosecond)) => {
525 let naive_date =
526 chrono::NaiveDate::from_ymd_opt(*year, (*month).into(), (*day).into())
527 .ok_or_else(|| {
528 Error::Decode(format!(
529 "invalid date y={}, m={}, d={}",
530 year, month, day
531 ))
532 })?;
533 let naive_time = chrono::NaiveTime::from_hms_nano_opt(
534 (*hour).into(),
535 (*minute).into(),
536 (*second).into(),
537 *nanosecond,
538 )
539 .ok_or_else(|| {
540 Error::Decode(format!(
541 "invalid time {}:{}:{}:{}",
542 hour, minute, second, nanosecond
543 ))
544 })?;
545 let dt = chrono::NaiveDateTime::new(naive_date, naive_time);
546 Ok(dt)
547 }
548 _ => Err(Error::Decode(format_decode_err("DATETIME", value))),
549 }
550 }
551}
552
553impl Decode for chrono::Duration {
554 fn decode(value: &DbValue) -> Result<Self, Error> {
555 match value {
556 DbValue::Timestamp(n) => Ok(chrono::Duration::seconds(*n)),
557 _ => Err(Error::Decode(format_decode_err("BIGINT", value))),
558 }
559 }
560}
561
/// Decodes a `UUID` value by parsing its textual form.
#[cfg(feature = "postgres4-types")]
impl Decode for uuid::Uuid {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        if let DbValue::Uuid(text) = value {
            return uuid::Uuid::parse_str(text).map_err(|err| Error::Decode(err.to_string()));
        }
        Err(Error::Decode(format_decode_err("UUID", value)))
    }
}
571
/// Decodes a `JSONB` value into an untyped JSON tree.
#[cfg(feature = "json")]
impl Decode for serde_json::Value {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        from_jsonb::<Self>(value)
    }
}
578
/// Deserializes a `JSONB` value into any `T: Deserialize`.
///
/// Because `T` is tied to the lifetime of `value`, it may borrow from the JSON bytes.
///
/// # Errors
///
/// Returns [`Error::Decode`] if `value` is not `Jsonb` or the JSON does not
/// deserialize into `T`.
#[cfg(feature = "json")]
pub fn from_jsonb<'a, T: serde::Deserialize<'a>>(value: &'a DbValue) -> Result<T, Error> {
    let DbValue::Jsonb(bytes) = value else {
        return Err(Error::Decode(format_decode_err("JSONB", value)));
    };
    serde_json::from_slice(bytes).map_err(|err| Error::Decode(err.to_string()))
}
587
/// Decodes a `NUMERIC` value, parsing its textual form exactly.
#[cfg(feature = "postgres4-types")]
impl Decode for rust_decimal::Decimal {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::Decimal(text) = value else {
            return Err(Error::Decode(format_decode_err("NUMERIC", value)));
        };
        rust_decimal::Decimal::from_str_exact(text).map_err(|err| Error::Decode(err.to_string()))
    }
}
599
/// Maps a WIT range-bound kind onto the equivalent `postgres_range` bound type.
#[cfg(feature = "postgres4-types")]
fn bound_type_from_wit(kind: RangeBoundKind) -> postgres_range::BoundType {
    use postgres_range::BoundType;
    match kind {
        RangeBoundKind::Exclusive => BoundType::Exclusive,
        RangeBoundKind::Inclusive => BoundType::Inclusive,
    }
}
607
/// Decodes an `INT4RANGE` value; a missing bound becomes an unbounded end.
#[cfg(feature = "postgres4-types")]
impl Decode for postgres_range::Range<i32> {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::RangeInt32((lower_wit, upper_wit)) = value else {
            return Err(Error::Decode(format_decode_err("INT4RANGE", value)));
        };
        let lower = lower_wit.map(|(endpoint, kind)| {
            postgres_range::RangeBound::new(endpoint, bound_type_from_wit(kind))
        });
        let upper = upper_wit.map(|(endpoint, kind)| {
            postgres_range::RangeBound::new(endpoint, bound_type_from_wit(kind))
        });
        Ok(postgres_range::Range::new(lower, upper))
    }
}
625
/// Decodes an `INT8RANGE` value; a missing bound becomes an unbounded end.
#[cfg(feature = "postgres4-types")]
impl Decode for postgres_range::Range<i64> {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::RangeInt64((lower_wit, upper_wit)) = value else {
            return Err(Error::Decode(format_decode_err("INT8RANGE", value)));
        };
        let lower = lower_wit.map(|(endpoint, kind)| {
            postgres_range::RangeBound::new(endpoint, bound_type_from_wit(kind))
        });
        let upper = upper_wit.map(|(endpoint, kind)| {
            postgres_range::RangeBound::new(endpoint, bound_type_from_wit(kind))
        });
        Ok(postgres_range::Range::new(lower, upper))
    }
}
643
/// Decodes a `NUMERICRANGE` value into a `(lower, upper)` pair of optional
/// `(Decimal, RangeBoundKind)` bounds.
///
/// The lower bound is parsed first, so its parse error is reported if both are bad.
#[cfg(feature = "postgres4-types")]
impl Decode
    for (
        Option<(rust_decimal::Decimal, RangeBoundKind)>,
        Option<(rust_decimal::Decimal, RangeBoundKind)>,
    )
{
    fn decode(value: &DbValue) -> Result<Self, Error> {
        let DbValue::RangeDecimal((lbound, ubound)) = value else {
            return Err(Error::Decode(format_decode_err("NUMERICRANGE", value)));
        };

        let parse_bound = |text: &str, kind: RangeBoundKind| {
            rust_decimal::Decimal::from_str_exact(text)
                .map(|dec| (dec, kind))
                .map_err(|e| Error::Decode(e.to_string()))
        };

        let lower = match lbound {
            Some((text, kind)) => Some(parse_bound(text, *kind)?),
            None => None,
        };
        let upper = match ubound {
            Some((text, kind)) => Some(parse_bound(text, *kind)?),
            None => None,
        };
        Ok((lower, upper))
    }
}
679
680impl Decode for Vec<Option<i32>> {
683 fn decode(value: &DbValue) -> Result<Self, Error> {
684 match value {
685 DbValue::ArrayInt32(a) => Ok(a.to_vec()),
686 _ => Err(Error::Decode(format_decode_err("INT4[]", value))),
687 }
688 }
689}
690
691impl Decode for Vec<Option<i64>> {
692 fn decode(value: &DbValue) -> Result<Self, Error> {
693 match value {
694 DbValue::ArrayInt64(a) => Ok(a.to_vec()),
695 _ => Err(Error::Decode(format_decode_err("INT8[]", value))),
696 }
697 }
698}
699
700impl Decode for Vec<Option<String>> {
701 fn decode(value: &DbValue) -> Result<Self, Error> {
702 match value {
703 DbValue::ArrayStr(a) => Ok(a.to_vec()),
704 _ => Err(Error::Decode(format_decode_err("TEXT[]", value))),
705 }
706 }
707}
708
/// Parses an optional decimal string exactly, passing `None` through unchanged.
#[cfg(feature = "postgres4-types")]
fn map_decimal(s: &Option<String>) -> Result<Option<rust_decimal::Decimal>, Error> {
    match s {
        None => Ok(None),
        Some(text) => rust_decimal::Decimal::from_str_exact(text)
            .map(Some)
            .map_err(|e| Error::Decode(e.to_string())),
    }
}
716
/// Decodes a `NUMERIC[]` value; NULL elements become `None`, and the first element
/// that fails to parse aborts decoding with its error.
#[cfg(feature = "postgres4-types")]
impl Decode for Vec<Option<rust_decimal::Decimal>> {
    fn decode(value: &DbValue) -> Result<Self, Error> {
        if let DbValue::ArrayDecimal(items) = value {
            items.iter().map(map_decimal).collect()
        } else {
            Err(Error::Decode(format_decode_err("NUMERIC[]", value)))
        }
    }
}
729
730impl Decode for Interval {
731 fn decode(value: &DbValue) -> Result<Self, Error> {
732 match value {
733 DbValue::Interval(i) => Ok(*i),
734 _ => Err(Error::Decode(format_decode_err("INTERVAL", value))),
735 }
736 }
737}
738
// Generates `impl From<$source> for ParameterValue` for each `$source => $variant`
// pair. Each impl wraps the value, unchanged, in `ParameterValue::$variant`.
macro_rules! impl_parameter_value_conversions {
    ($($source:ty => $variant:ident),*) => {
        $(
            impl From<$source> for ParameterValue {
                fn from(v: $source) -> ParameterValue {
                    ParameterValue::$variant(v)
                }
            }
        )*
    };
}
749}
750
751impl_parameter_value_conversions! {
752 i8 => Int8,
753 i16 => Int16,
754 i32 => Int32,
755 i64 => Int64,
756 f32 => Floating32,
757 f64 => Floating64,
758 bool => Boolean,
759 String => Str,
760 Vec<u8> => Binary,
761 Vec<Option<i32>> => ArrayInt32,
762 Vec<Option<i64>> => ArrayInt64,
763 Vec<Option<String>> => ArrayStr
764}
765
766impl From<chrono::NaiveDateTime> for ParameterValue {
767 fn from(v: chrono::NaiveDateTime) -> ParameterValue {
768 ParameterValue::Datetime((
769 v.year(),
770 v.month() as u8,
771 v.day() as u8,
772 v.hour() as u8,
773 v.minute() as u8,
774 v.second() as u8,
775 v.nanosecond(),
776 ))
777 }
778}
779
780impl From<chrono::NaiveTime> for ParameterValue {
781 fn from(v: chrono::NaiveTime) -> ParameterValue {
782 ParameterValue::Time((
783 v.hour() as u8,
784 v.minute() as u8,
785 v.second() as u8,
786 v.nanosecond(),
787 ))
788 }
789}
790
791impl From<chrono::NaiveDate> for ParameterValue {
792 fn from(v: chrono::NaiveDate) -> ParameterValue {
793 ParameterValue::Date((v.year(), v.month() as u8, v.day() as u8))
794 }
795}
796
797impl From<chrono::TimeDelta> for ParameterValue {
798 fn from(v: chrono::TimeDelta) -> ParameterValue {
799 ParameterValue::Timestamp(v.num_seconds())
800 }
801}
802
/// Sends a UUID in its standard textual form.
#[cfg(feature = "postgres4-types")]
impl From<uuid::Uuid> for ParameterValue {
    fn from(v: uuid::Uuid) -> ParameterValue {
        let text = v.to_string();
        Self::Uuid(text)
    }
}
809
/// Serializes a JSON tree into a `Jsonb` parameter.
#[cfg(feature = "json")]
impl TryFrom<serde_json::Value> for ParameterValue {
    type Error = serde_json::Error;

    fn try_from(v: serde_json::Value) -> Result<ParameterValue, Self::Error> {
        let encoded = jsonb(&v)?;
        Ok(encoded)
    }
}
818
/// Serializes `value` to JSON bytes and wraps them as a `Jsonb` parameter.
///
/// # Errors
///
/// Returns the `serde_json::Error` raised if `value` cannot be serialized.
#[cfg(feature = "json")]
pub fn jsonb<T: serde::Serialize>(value: &T) -> Result<ParameterValue, serde_json::Error> {
    serde_json::to_vec(value).map(ParameterValue::Jsonb)
}
825
/// Sends a decimal in its string representation.
#[cfg(feature = "postgres4-types")]
impl From<rust_decimal::Decimal> for ParameterValue {
    fn from(v: rust_decimal::Decimal) -> ParameterValue {
        let text = v.to_string();
        Self::Decimal(text)
    }
}
832
833#[allow(
837 clippy::type_complexity,
838 reason = "I sure hope 'blame Alex' works here too"
839)]
840fn range_bounds_to_wit<T, U>(
841 range: impl std::ops::RangeBounds<T>,
842 f: impl Fn(&T) -> U,
843) -> (Option<(U, RangeBoundKind)>, Option<(U, RangeBoundKind)>) {
844 (
845 range_bound_to_wit(range.start_bound(), &f),
846 range_bound_to_wit(range.end_bound(), &f),
847 )
848}
849
850fn range_bound_to_wit<T, U>(
851 bound: std::ops::Bound<&T>,
852 f: &dyn Fn(&T) -> U,
853) -> Option<(U, RangeBoundKind)> {
854 match bound {
855 std::ops::Bound::Included(v) => Some((f(v), RangeBoundKind::Inclusive)),
856 std::ops::Bound::Excluded(v) => Some((f(v), RangeBoundKind::Exclusive)),
857 std::ops::Bound::Unbounded => None,
858 }
859}
860
/// Converts a `postgres_range` bound into the WIT `(value, kind)` representation.
#[cfg(feature = "postgres4-types")]
fn pg_range_bound_to_wit<S: postgres_range::BoundSided, T: Copy>(
    bound: &postgres_range::RangeBound<S, T>,
) -> (T, RangeBoundKind) {
    use postgres_range::BoundType;

    let kind = match bound.type_ {
        BoundType::Exclusive => RangeBoundKind::Exclusive,
        BoundType::Inclusive => RangeBoundKind::Inclusive,
    };
    (bound.value, kind)
}
871
872impl From<std::ops::Range<i32>> for ParameterValue {
873 fn from(v: std::ops::Range<i32>) -> ParameterValue {
874 ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
875 }
876}
877
878impl From<std::ops::RangeInclusive<i32>> for ParameterValue {
879 fn from(v: std::ops::RangeInclusive<i32>) -> ParameterValue {
880 ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
881 }
882}
883
884impl From<std::ops::RangeFrom<i32>> for ParameterValue {
885 fn from(v: std::ops::RangeFrom<i32>) -> ParameterValue {
886 ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
887 }
888}
889
890impl From<std::ops::RangeTo<i32>> for ParameterValue {
891 fn from(v: std::ops::RangeTo<i32>) -> ParameterValue {
892 ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
893 }
894}
895
896impl From<std::ops::RangeToInclusive<i32>> for ParameterValue {
897 fn from(v: std::ops::RangeToInclusive<i32>) -> ParameterValue {
898 ParameterValue::RangeInt32(range_bounds_to_wit(v, |n| *n))
899 }
900}
901
/// Sends a `postgres_range::Range<i32>` as an `INT4RANGE` parameter.
#[cfg(feature = "postgres4-types")]
impl From<postgres_range::Range<i32>> for ParameterValue {
    fn from(v: postgres_range::Range<i32>) -> ParameterValue {
        Self::RangeInt32((
            v.lower().map(pg_range_bound_to_wit),
            v.upper().map(pg_range_bound_to_wit),
        ))
    }
}
910
911impl From<std::ops::Range<i64>> for ParameterValue {
912 fn from(v: std::ops::Range<i64>) -> ParameterValue {
913 ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
914 }
915}
916
917impl From<std::ops::RangeInclusive<i64>> for ParameterValue {
918 fn from(v: std::ops::RangeInclusive<i64>) -> ParameterValue {
919 ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
920 }
921}
922
923impl From<std::ops::RangeFrom<i64>> for ParameterValue {
924 fn from(v: std::ops::RangeFrom<i64>) -> ParameterValue {
925 ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
926 }
927}
928
929impl From<std::ops::RangeTo<i64>> for ParameterValue {
930 fn from(v: std::ops::RangeTo<i64>) -> ParameterValue {
931 ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
932 }
933}
934
935impl From<std::ops::RangeToInclusive<i64>> for ParameterValue {
936 fn from(v: std::ops::RangeToInclusive<i64>) -> ParameterValue {
937 ParameterValue::RangeInt64(range_bounds_to_wit(v, |n| *n))
938 }
939}
940
/// Sends a `postgres_range::Range<i64>` as an `INT8RANGE` parameter.
#[cfg(feature = "postgres4-types")]
impl From<postgres_range::Range<i64>> for ParameterValue {
    fn from(v: postgres_range::Range<i64>) -> ParameterValue {
        Self::RangeInt64((
            v.lower().map(pg_range_bound_to_wit),
            v.upper().map(pg_range_bound_to_wit),
        ))
    }
}
949
/// Sends a half-open `Decimal` range as a decimal-range parameter, with each bound
/// in its string representation.
#[cfg(feature = "postgres4-types")]
impl From<std::ops::Range<rust_decimal::Decimal>> for ParameterValue {
    fn from(v: std::ops::Range<rust_decimal::Decimal>) -> ParameterValue {
        Self::RangeDecimal(range_bounds_to_wit(v, ToString::to_string))
    }
}
956
957impl From<Vec<i32>> for ParameterValue {
958 fn from(v: Vec<i32>) -> ParameterValue {
959 ParameterValue::ArrayInt32(v.into_iter().map(Some).collect())
960 }
961}
962
963impl From<Vec<i64>> for ParameterValue {
964 fn from(v: Vec<i64>) -> ParameterValue {
965 ParameterValue::ArrayInt64(v.into_iter().map(Some).collect())
966 }
967}
968
969impl From<Vec<String>> for ParameterValue {
970 fn from(v: Vec<String>) -> ParameterValue {
971 ParameterValue::ArrayStr(v.into_iter().map(Some).collect())
972 }
973}
974
/// Sends decimals as an array of strings; `None` entries are preserved as `None`.
#[cfg(feature = "postgres4-types")]
impl From<Vec<Option<rust_decimal::Decimal>>> for ParameterValue {
    fn from(v: Vec<Option<rust_decimal::Decimal>>) -> ParameterValue {
        let texts: Vec<Option<String>> = v
            .iter()
            .map(|maybe| maybe.as_ref().map(ToString::to_string))
            .collect();
        Self::ArrayDecimal(texts)
    }
}

/// Sends decimals as an array of strings with no NULL elements.
#[cfg(feature = "postgres4-types")]
impl From<Vec<rust_decimal::Decimal>> for ParameterValue {
    fn from(v: Vec<rust_decimal::Decimal>) -> ParameterValue {
        Self::ArrayDecimal(v.iter().map(|dec| Some(dec.to_string())).collect())
    }
}
993
994impl From<Interval> for ParameterValue {
995 fn from(v: Interval) -> ParameterValue {
996 ParameterValue::Interval(v)
997 }
998}
999
1000impl<T: Into<ParameterValue>> From<Option<T>> for ParameterValue {
1001 fn from(o: Option<T>) -> ParameterValue {
1002 match o {
1003 Some(v) => v.into(),
1004 None => ParameterValue::DbNull,
1005 }
1006 }
1007}
1008
1009fn format_decode_err(types: &str, value: &DbValue) -> String {
1010 format!("Expected {} from the DB but got {:?}", types, value)
1011}
1012
#[cfg(test)]
mod tests {
    use chrono::NaiveDateTime;

    use super::*;

    #[test]
    fn boolean() {
        assert!(bool::decode(&DbValue::Boolean(true)).unwrap());
        assert!(bool::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<bool>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn int16() {
        assert_eq!(i16::decode(&DbValue::Int16(0)).unwrap(), 0);
        assert!(i16::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<i16>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn int32() {
        assert_eq!(i32::decode(&DbValue::Int32(0)).unwrap(), 0);
        assert!(i32::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<i32>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn int64() {
        assert_eq!(i64::decode(&DbValue::Int64(0)).unwrap(), 0);
        assert!(i64::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<i64>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn floating32() {
        assert!(f32::decode(&DbValue::Floating32(0.0)).is_ok());
        assert!(f32::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<f32>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn floating64() {
        assert!(f64::decode(&DbValue::Floating64(0.0)).is_ok());
        assert!(f64::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<f64>::decode(&DbValue::DbNull).unwrap().is_none());
    }

    #[test]
    fn str() {
        assert_eq!(
            String::decode(&DbValue::Str(String::from("foo"))).unwrap(),
            String::from("foo")
        );

        assert!(String::decode(&DbValue::Int32(0)).is_err());
        assert!(Option::<String>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn binary() {
        assert!(Vec::<u8>::decode(&DbValue::Binary(vec![0, 0])).is_ok());
        assert!(Vec::<u8>::decode(&DbValue::Boolean(false)).is_err());
        assert!(Option::<Vec<u8>>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn date() {
        assert_eq!(
            chrono::NaiveDate::decode(&DbValue::Date((1, 2, 4))).unwrap(),
            chrono::NaiveDate::from_ymd_opt(1, 2, 4).unwrap()
        );
        assert_ne!(
            chrono::NaiveDate::decode(&DbValue::Date((1, 2, 4))).unwrap(),
            chrono::NaiveDate::from_ymd_opt(1, 2, 5).unwrap()
        );
        assert!(Option::<chrono::NaiveDate>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn invalid_date_is_decode_error() {
        // February 30th does not exist.
        assert!(matches!(
            chrono::NaiveDate::decode(&DbValue::Date((2023, 2, 30))),
            Err(Error::Decode(_))
        ));
    }

    #[test]
    fn time() {
        assert_eq!(
            chrono::NaiveTime::decode(&DbValue::Time((1, 2, 3, 4))).unwrap(),
            chrono::NaiveTime::from_hms_nano_opt(1, 2, 3, 4).unwrap()
        );
        assert_ne!(
            chrono::NaiveTime::decode(&DbValue::Time((1, 2, 3, 4))).unwrap(),
            chrono::NaiveTime::from_hms_nano_opt(1, 2, 4, 5).unwrap()
        );
        assert!(Option::<chrono::NaiveTime>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn datetime() {
        let date = chrono::NaiveDate::from_ymd_opt(1, 2, 3).unwrap();
        let mut time = chrono::NaiveTime::from_hms_nano_opt(4, 5, 6, 7).unwrap();
        assert_eq!(
            chrono::NaiveDateTime::decode(&DbValue::Datetime((1, 2, 3, 4, 5, 6, 7))).unwrap(),
            chrono::NaiveDateTime::new(date, time)
        );

        time = chrono::NaiveTime::from_hms_nano_opt(4, 5, 6, 8).unwrap();
        assert_ne!(
            NaiveDateTime::decode(&DbValue::Datetime((1, 2, 3, 4, 5, 6, 7))).unwrap(),
            chrono::NaiveDateTime::new(date, time)
        );
        assert!(Option::<chrono::NaiveDateTime>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn timestamp() {
        assert_eq!(
            chrono::Duration::decode(&DbValue::Timestamp(1)).unwrap(),
            chrono::Duration::seconds(1),
        );
        assert_ne!(
            chrono::Duration::decode(&DbValue::Timestamp(2)).unwrap(),
            chrono::Duration::seconds(1)
        );
        assert!(Option::<chrono::Duration>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    #[test]
    fn decode_error_names_expected_type() {
        match bool::decode(&DbValue::Int32(0)) {
            Err(Error::Decode(msg)) => {
                assert!(msg.starts_with("Expected BOOL from the DB but got "), "{msg}")
            }
            other => panic!("unexpected result: {other:?}"),
        }
    }

    #[test]
    fn parameter_value_from_option() {
        assert!(matches!(
            ParameterValue::from(Some(5i32)),
            ParameterValue::Int32(5)
        ));
        assert!(matches!(
            ParameterValue::from(None::<i32>),
            ParameterValue::DbNull
        ));
    }

    #[test]
    #[cfg(feature = "postgres4-types")]
    fn uuid() {
        let uuid_str = "12341234-1234-1234-1234-123412341234";
        assert_eq!(
            uuid::Uuid::try_parse(uuid_str).unwrap(),
            uuid::Uuid::decode(&DbValue::Uuid(uuid_str.to_owned())).unwrap(),
        );
        assert!(Option::<uuid::Uuid>::decode(&DbValue::DbNull)
            .unwrap()
            .is_none());
    }

    // Only the `json`-gated test below uses this struct, so it is gated the same way.
    // Otherwise it is dead code without the feature. NOTE(review): presumably `serde`
    // is also an optional dependency enabled by `json`, in which case the derive would
    // not compile at all without this gate.
    #[cfg(feature = "json")]
    #[derive(Debug, serde::Deserialize, PartialEq)]
    struct JsonTest {
        hello: String,
    }

    #[test]
    #[cfg(feature = "json")]
    fn jsonb() {
        let json_val = serde_json::json!({
            "hello": "world"
        });
        let dbval = DbValue::Jsonb(r#"{"hello":"world"}"#.into());

        assert_eq!(json_val, serde_json::Value::decode(&dbval).unwrap(),);

        let json_struct = JsonTest {
            hello: "world".to_owned(),
        };
        assert_eq!(json_struct, from_jsonb(&dbval).unwrap());
    }

    #[test]
    #[cfg(feature = "postgres4-types")]
    fn ranges() {
        let i32_range = postgres_range::Range::<i32>::decode(&DbValue::RangeInt32((
            Some((45, RangeBoundKind::Inclusive)),
            Some((89, RangeBoundKind::Exclusive)),
        )))
        .unwrap();
        assert_eq!(45, i32_range.lower().unwrap().value);
        assert_eq!(
            postgres_range::BoundType::Inclusive,
            i32_range.lower().unwrap().type_
        );
        assert_eq!(89, i32_range.upper().unwrap().value);
        assert_eq!(
            postgres_range::BoundType::Exclusive,
            i32_range.upper().unwrap().type_
        );

        let i32_range_from = postgres_range::Range::<i32>::decode(&DbValue::RangeInt32((
            Some((45, RangeBoundKind::Inclusive)),
            None,
        )))
        .unwrap();
        assert!(i32_range_from.upper().is_none());

        let i64_range = postgres_range::Range::<i64>::decode(&DbValue::RangeInt64((
            Some((4567456745674567, RangeBoundKind::Inclusive)),
            Some((890189018901890189, RangeBoundKind::Exclusive)),
        )))
        .unwrap();
        assert_eq!(4567456745674567, i64_range.lower().unwrap().value);
        assert_eq!(890189018901890189, i64_range.upper().unwrap().value);

        #[allow(clippy::type_complexity)]
        let (dec_lbound, dec_ubound): (
            Option<(rust_decimal::Decimal, RangeBoundKind)>,
            Option<(rust_decimal::Decimal, RangeBoundKind)>,
        ) = Decode::decode(&DbValue::RangeDecimal((
            Some(("4567.8901".to_owned(), RangeBoundKind::Inclusive)),
            Some(("8901.2345678901".to_owned(), RangeBoundKind::Exclusive)),
        )))
        .unwrap();
        assert_eq!(
            rust_decimal::Decimal::from_i128_with_scale(45678901, 4),
            dec_lbound.unwrap().0
        );
        assert_eq!(
            rust_decimal::Decimal::from_i128_with_scale(89012345678901, 10),
            dec_ubound.unwrap().0
        );
    }

    #[test]
    #[cfg(feature = "postgres4-types")]
    fn arrays() {
        let v32 = vec![Some(123), None, Some(456)];
        let i32_arr = Vec::<Option<i32>>::decode(&DbValue::ArrayInt32(v32.clone())).unwrap();
        assert_eq!(v32, i32_arr);

        let v64 = vec![Some(123), None, Some(456)];
        let i64_arr = Vec::<Option<i64>>::decode(&DbValue::ArrayInt64(v64.clone())).unwrap();
        assert_eq!(v64, i64_arr);

        let vdec = vec![Some("1.23".to_owned()), None];
        let dec_arr =
            Vec::<Option<rust_decimal::Decimal>>::decode(&DbValue::ArrayDecimal(vdec)).unwrap();
        assert_eq!(
            vec![
                Some(rust_decimal::Decimal::from_i128_with_scale(123, 2)),
                None
            ],
            dec_arr
        );

        let vstr = vec![Some("alice".to_owned()), None, Some("bob".to_owned())];
        let str_arr = Vec::<Option<String>>::decode(&DbValue::ArrayStr(vstr.clone())).unwrap();
        assert_eq!(vstr, str_arr);
    }
}