sqlx_postgres/
bind_iter.rs

1use crate::{type_info::PgType, PgArgumentBuffer, PgHasArrayType, PgTypeInfo, Postgres};
2use core::cell::Cell;
3use sqlx_core::{
4    database::Database,
5    encode::{Encode, IsNull},
6    error::BoxDynError,
7    types::Type,
8};
9
/// Array argument produced by [`PgBindIterExt::bind_iter`] that is encoded lazily.
///
/// This type is not exported. It must still be `pub` because it appears in the signature of
/// the public extension trait.
///
/// The wrapped iterator is consumed the first time the value is encoded. It is stored in a
/// `Cell<Option<_>>` because `Encode::encode_by_ref` only receives `&self` but still needs
/// ownership of the iterator in order to drain it.
#[must_use = "a `PgBindIter` does nothing unless it is bound to a query"]
pub struct PgBindIter<I>(Cell<Option<I>>);

/// Iterator extension trait for binding an iterator as a Postgres array.
///
/// Calling [`bind_iter`](PgBindIterExt::bind_iter) wraps any iterator so it can be passed
/// directly to `.bind(...)`. Its items are encoded straight into the argument buffer as a
/// one-dimensional array, without first being collected into a `Vec`.
///
/// `PgHasArrayType` has a blanket impl for references, so the iterator may yield borrowed
/// items (`&T`). There is no need to clone or copy them.
///
/// Without this trait, the example below would need three separate `Vec`s:
/// - one pass over `people` per column to collect the values;
/// - then another pass over each `Vec` to encode it.
///
/// With `bind_iter`, each column is encoded in a single pass over `people` and no
/// intermediate collections are allocated. This saves both time and memory, and the data
/// can come from any kind of collection.
///
/// Each wrapped iterator is consumed when the argument is encoded, so a given `PgBindIter`
/// can only be encoded once.
///
/// ```rust,no_run
/// # async fn test_bind_iter() -> Result<(), sqlx::error::BoxDynError> {
/// # use sqlx::types::chrono::{DateTime, Utc};
/// # use sqlx::Connection;
/// # fn people() -> &'static [Person] {
/// #   &[]
/// # }
/// # let mut conn = <sqlx::Postgres as sqlx::Database>::Connection::connect("dummyurl").await?;
/// use sqlx::postgres::PgBindIterExt;
///
/// #[derive(sqlx::FromRow)]
/// struct Person {
///     id: i64,
///     name: String,
///     birthdate: DateTime<Utc>,
/// }
///
/// # let people: &[Person] = people();
/// sqlx::query("insert into person(id, name, birthdate) select * from unnest($1, $2, $3)")
///     .bind(people.iter().map(|p| p.id).bind_iter())
///     .bind(people.iter().map(|p| &p.name).bind_iter())
///     .bind(people.iter().map(|p| &p.birthdate).bind_iter())
///     .execute(&mut conn)
///     .await?;
///
/// # Ok(())
/// # }
/// ```
pub trait PgBindIterExt: Iterator + Sized {
    /// Wraps this iterator so it can be bound as a Postgres array parameter.
    fn bind_iter(self) -> PgBindIter<Self>;
}

// Type parameters are implicitly `Sized`, which satisfies the trait's `Sized` supertrait.
impl<I: Iterator> PgBindIterExt for I {
    fn bind_iter(self) -> PgBindIter<I> {
        PgBindIter(Cell::new(Some(self)))
    }
}
64
65impl<I> Type<Postgres> for PgBindIter<I>
66where
67    I: Iterator,
68    <I as Iterator>::Item: Type<Postgres> + PgHasArrayType,
69{
70    fn type_info() -> <Postgres as Database>::TypeInfo {
71        <I as Iterator>::Item::array_type_info()
72    }
73    fn compatible(ty: &PgTypeInfo) -> bool {
74        <I as Iterator>::Item::array_compatible(ty)
75    }
76}
77
78impl<'q, I> PgBindIter<I>
79where
80    I: Iterator,
81    <I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
82{
83    fn encode_inner(
84        // need ownership to iterate
85        mut iter: I,
86        buf: &mut PgArgumentBuffer,
87    ) -> Result<IsNull, BoxDynError> {
88        let lower_size_hint = iter.size_hint().0;
89        let first = iter.next();
90        let type_info = first
91            .as_ref()
92            .and_then(Encode::produces)
93            .unwrap_or_else(<I as Iterator>::Item::type_info);
94
95        buf.extend(&1_i32.to_be_bytes()); // number of dimensions
96        buf.extend(&0_i32.to_be_bytes()); // flags
97
98        match type_info.0 {
99            PgType::DeclareWithName(name) => buf.patch_type_by_name(&name),
100            PgType::DeclareArrayOf(array) => buf.patch_array_type(array),
101
102            ty => {
103                buf.extend(&ty.oid().0.to_be_bytes());
104            }
105        }
106
107        let len_start = buf.len();
108        buf.extend(0_i32.to_be_bytes()); // len (unknown so far)
109        buf.extend(1_i32.to_be_bytes()); // lower bound
110
111        match first {
112            Some(first) => buf.encode(first)?,
113            None => return Ok(IsNull::No),
114        }
115
116        let mut count = 1_i32;
117        const MAX: usize = i32::MAX as usize - 1;
118
119        for value in (&mut iter).take(MAX) {
120            buf.encode(value)?;
121            count += 1;
122        }
123
124        const OVERFLOW: usize = i32::MAX as usize + 1;
125        if iter.next().is_some() {
126            let iter_size = std::cmp::max(lower_size_hint, OVERFLOW);
127            return Err(format!("encoded iterator is too large for Postgres: {iter_size}").into());
128        }
129
130        // set the length now that we know what it is.
131        buf[len_start..(len_start + 4)].copy_from_slice(&count.to_be_bytes());
132
133        Ok(IsNull::No)
134    }
135}
136
137impl<'q, I> Encode<'q, Postgres> for PgBindIter<I>
138where
139    I: Iterator,
140    <I as Iterator>::Item: Type<Postgres> + Encode<'q, Postgres>,
141{
142    fn encode_by_ref(&self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError> {
143        Self::encode_inner(self.0.take().expect("PgBindIter is only used once"), buf)
144    }
145    fn encode(self, buf: &mut PgArgumentBuffer) -> Result<IsNull, BoxDynError>
146    where
147        Self: Sized,
148    {
149        Self::encode_inner(
150            self.0.into_inner().expect("PgBindIter is only used once"),
151            buf,
152        )
153    }
154}