sqlx_postgres/
arguments.rs

1use std::fmt::{self, Write};
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::encode::{Encode, IsNull};
6use crate::error::Error;
7use crate::ext::ustr::UStr;
8use crate::types::Type;
9use crate::{PgConnection, PgTypeInfo, Postgres};
10
11use crate::type_info::PgArrayOf;
12pub(crate) use sqlx_core::arguments::Arguments;
13use sqlx_core::error::BoxDynError;
14
// TODO: `buf.patch(|| ...)` is a poor name; can we think of a better one? Maybe `buf.lazy(|| ...)`?
// TODO: Extend the patch system to support dynamic lengths.
//       Considerations:
//          - The length prefix at the patched value's offset needs to be back-tracked and updated.
//          - message::Bind needs to take a &PgArguments and use a `write` method instead of
//            referencing a buffer directly.
//          - The basic idea is that we write bytes from the buffer until we reach an offset
//            that has a patch; we then apply the patch (which should write to a &mut Vec<u8>),
//            backtrack and update the length prefix, then write until the next patch offset.
24
25#[derive(Default, Debug, Clone)]
26pub struct PgArgumentBuffer {
27    buffer: Vec<u8>,
28
29    // Number of arguments
30    count: usize,
31
32    // Whenever an `Encode` impl needs to defer some work until after we resolve parameter types
33    // it can use `patch`.
34    //
35    // This currently is only setup to be useful if there is a *fixed-size* slot that needs to be
36    // tweaked from the input type. However, that's the only use case we currently have.
37    patches: Vec<Patch>,
38
39    // Whenever an `Encode` impl encounters a `PgTypeInfo` object that does not have an OID
40    // It pushes a "hole" that must be patched later.
41    //
42    // The hole is a `usize` offset into the buffer with the type name that should be resolved
43    // This is done for Records and Arrays as the OID is needed well before we are in an async
44    // function and can just ask postgres.
45    //
46    type_holes: Vec<(usize, HoleKind)>, // Vec<{ offset, type_name }>
47}
48
49#[derive(Debug, Clone)]
50enum HoleKind {
51    Type { name: UStr },
52    Array(Arc<PgArrayOf>),
53}
54
55#[derive(Clone)]
56struct Patch {
57    buf_offset: usize,
58    arg_index: usize,
59    #[allow(clippy::type_complexity)]
60    callback: Arc<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
61}
62
63impl fmt::Debug for Patch {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct("Patch")
66            .field("buf_offset", &self.buf_offset)
67            .field("arg_index", &self.arg_index)
68            .field("callback", &"<callback>")
69            .finish()
70    }
71}
72
73/// Implementation of [`Arguments`] for PostgreSQL.
74#[derive(Default, Debug, Clone)]
75pub struct PgArguments {
76    // Types of each bind parameter
77    pub(crate) types: Vec<PgTypeInfo>,
78
79    // Buffer of encoded bind parameters
80    pub(crate) buffer: PgArgumentBuffer,
81}
82
83impl PgArguments {
84    pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
85    where
86        T: Encode<'q, Postgres> + Type<Postgres>,
87    {
88        let type_info = value.produces().unwrap_or_else(T::type_info);
89
90        let buffer_snapshot = self.buffer.snapshot();
91
92        // encode the value into our buffer
93        if let Err(error) = self.buffer.encode(value) {
94            // reset the value buffer to its previous value if encoding failed,
95            // so we don't leave a half-encoded value behind
96            self.buffer.reset_to_snapshot(buffer_snapshot);
97            return Err(error);
98        };
99
100        // remember the type information for this value
101        self.types.push(type_info);
102        // increment the number of arguments we are tracking
103        self.buffer.count += 1;
104
105        Ok(())
106    }
107
108    // Apply patches
109    // This should only go out and ask postgres if we have not seen the type name yet
110    pub(crate) async fn apply_patches(
111        &mut self,
112        conn: &mut PgConnection,
113        parameters: &[PgTypeInfo],
114    ) -> Result<(), Error> {
115        let PgArgumentBuffer {
116            ref patches,
117            ref type_holes,
118            ref mut buffer,
119            ..
120        } = self.buffer;
121
122        for patch in patches {
123            let buf = &mut buffer[patch.buf_offset..];
124            let ty = &parameters[patch.arg_index];
125
126            (patch.callback)(buf, ty);
127        }
128
129        for (offset, kind) in type_holes {
130            let oid = match kind {
131                HoleKind::Type { name } => conn.fetch_type_id_by_name(name).await?,
132                HoleKind::Array(array) => conn.fetch_array_type_id(array).await?,
133            };
134            buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
135        }
136
137        Ok(())
138    }
139}
140
141impl<'q> Arguments<'q> for PgArguments {
142    type Database = Postgres;
143
144    fn reserve(&mut self, additional: usize, size: usize) {
145        self.types.reserve(additional);
146        self.buffer.reserve(size);
147    }
148
149    fn add<T>(&mut self, value: T) -> Result<(), BoxDynError>
150    where
151        T: Encode<'q, Self::Database> + Type<Self::Database>,
152    {
153        self.add(value)
154    }
155
156    fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
157        write!(writer, "${}", self.buffer.count)
158    }
159
160    #[inline(always)]
161    fn len(&self) -> usize {
162        self.buffer.count
163    }
164}
165
166impl PgArgumentBuffer {
167    pub(crate) fn encode<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
168    where
169        T: Encode<'q, Postgres>,
170    {
171        // Won't catch everything but is a good sanity check
172        value_size_int4_checked(value.size_hint())?;
173
174        // reserve space to write the prefixed length of the value
175        let offset = self.len();
176
177        self.extend(&[0; 4]);
178
179        // encode the value into our buffer
180        let len = if let IsNull::No = value.encode(self)? {
181            // Ensure that the value size does not overflow i32
182            value_size_int4_checked(self.len() - offset - 4)?
183        } else {
184            // Write a -1 to indicate NULL
185            // NOTE: It is illegal for [encode] to write any data
186            debug_assert_eq!(self.len(), offset + 4);
187            -1_i32
188        };
189
190        // write the len to the beginning of the value
191        // (offset + 4) cannot overflow because it would have failed at `self.extend()`.
192        self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
193
194        Ok(())
195    }
196
197    // Adds a callback to be invoked later when we know the parameter type
198    #[allow(dead_code)]
199    pub(crate) fn patch<F>(&mut self, callback: F)
200    where
201        F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync,
202    {
203        let offset = self.len();
204        let arg_index = self.count;
205
206        self.patches.push(Patch {
207            buf_offset: offset,
208            arg_index,
209            callback: Arc::new(callback),
210        });
211    }
212
213    // Extends the inner buffer by enough space to have an OID
214    // Remembers where the OID goes and type name for the OID
215    pub(crate) fn patch_type_by_name(&mut self, type_name: &UStr) {
216        let offset = self.len();
217
218        self.extend_from_slice(&0_u32.to_be_bytes());
219        self.type_holes.push((
220            offset,
221            HoleKind::Type {
222                name: type_name.clone(),
223            },
224        ));
225    }
226
227    pub(crate) fn patch_array_type(&mut self, array: Arc<PgArrayOf>) {
228        let offset = self.len();
229
230        self.extend_from_slice(&0_u32.to_be_bytes());
231        self.type_holes.push((offset, HoleKind::Array(array)));
232    }
233
234    fn snapshot(&self) -> PgArgumentBufferSnapshot {
235        let Self {
236            buffer,
237            count,
238            patches,
239            type_holes,
240        } = self;
241
242        PgArgumentBufferSnapshot {
243            buffer_length: buffer.len(),
244            count: *count,
245            patches_length: patches.len(),
246            type_holes_length: type_holes.len(),
247        }
248    }
249
250    fn reset_to_snapshot(
251        &mut self,
252        PgArgumentBufferSnapshot {
253            buffer_length,
254            count,
255            patches_length,
256            type_holes_length,
257        }: PgArgumentBufferSnapshot,
258    ) {
259        self.buffer.truncate(buffer_length);
260        self.count = count;
261        self.patches.truncate(patches_length);
262        self.type_holes.truncate(type_holes_length);
263    }
264}
265
/// Lengths of a [`PgArgumentBuffer`]'s growable parts at one point in time.
///
/// Restoring it (`reset_to_snapshot`) drops anything appended afterwards, which is how a
/// failed encode is undone.
struct PgArgumentBufferSnapshot {
    /// Length of the encoded byte buffer.
    buffer_length: usize,
    /// Number of arguments.
    count: usize,
    /// Number of registered patches.
    patches_length: usize,
    /// Number of pending type holes.
    type_holes_length: usize,
}
272
273impl Deref for PgArgumentBuffer {
274    type Target = Vec<u8>;
275
276    #[inline]
277    fn deref(&self) -> &Self::Target {
278        &self.buffer
279    }
280}
281
282impl DerefMut for PgArgumentBuffer {
283    #[inline]
284    fn deref_mut(&mut self) -> &mut Self::Target {
285        &mut self.buffer
286    }
287}
288
/// Converts a byte count to the `i32` length used by the Postgres binary protocol.
///
/// # Errors
/// Returns a descriptive message when `size` is larger than `i32::MAX`.
pub(crate) fn value_size_int4_checked(size: usize) -> Result<i32, String> {
    match i32::try_from(size) {
        Ok(len) => Ok(len),
        Err(_) => Err(format!(
            "value size would overflow in the binary protocol encoding: {size} > {}",
            i32::MAX
        )),
    }
}