1use std::fmt::{self, Write};
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::encode::{Encode, IsNull};
6use crate::error::Error;
7use crate::ext::ustr::UStr;
8use crate::types::Type;
9use crate::{PgConnection, PgTypeInfo, Postgres};
10
11use crate::type_info::PgArrayOf;
12pub(crate) use sqlx_core::arguments::Arguments;
13use sqlx_core::error::BoxDynError;
14
15#[derive(Default, Debug, Clone)]
26pub struct PgArgumentBuffer {
27    buffer: Vec<u8>,
28
29    count: usize,
31
32    patches: Vec<Patch>,
38
39    type_holes: Vec<(usize, HoleKind)>, }
48
49#[derive(Debug, Clone)]
50enum HoleKind {
51    Type { name: UStr },
52    Array(Arc<PgArrayOf>),
53}
54
55#[derive(Clone)]
56struct Patch {
57    buf_offset: usize,
58    arg_index: usize,
59    #[allow(clippy::type_complexity)]
60    callback: Arc<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
61}
62
63impl fmt::Debug for Patch {
64    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65        f.debug_struct("Patch")
66            .field("buf_offset", &self.buf_offset)
67            .field("arg_index", &self.arg_index)
68            .field("callback", &"<callback>")
69            .finish()
70    }
71}
72
73#[derive(Default, Debug, Clone)]
75pub struct PgArguments {
76    pub(crate) types: Vec<PgTypeInfo>,
78
79    pub(crate) buffer: PgArgumentBuffer,
81}
82
83impl PgArguments {
84    pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
85    where
86        T: Encode<'q, Postgres> + Type<Postgres>,
87    {
88        let type_info = value.produces().unwrap_or_else(T::type_info);
89
90        let buffer_snapshot = self.buffer.snapshot();
91
92        if let Err(error) = self.buffer.encode(value) {
94            self.buffer.reset_to_snapshot(buffer_snapshot);
97            return Err(error);
98        };
99
100        self.types.push(type_info);
102        self.buffer.count += 1;
104
105        Ok(())
106    }
107
108    pub(crate) async fn apply_patches(
111        &mut self,
112        conn: &mut PgConnection,
113        parameters: &[PgTypeInfo],
114    ) -> Result<(), Error> {
115        let PgArgumentBuffer {
116            ref patches,
117            ref type_holes,
118            ref mut buffer,
119            ..
120        } = self.buffer;
121
122        for patch in patches {
123            let buf = &mut buffer[patch.buf_offset..];
124            let ty = ¶meters[patch.arg_index];
125
126            (patch.callback)(buf, ty);
127        }
128
129        for (offset, kind) in type_holes {
130            let oid = match kind {
131                HoleKind::Type { name } => conn.fetch_type_id_by_name(name).await?,
132                HoleKind::Array(array) => conn.fetch_array_type_id(array).await?,
133            };
134            buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
135        }
136
137        Ok(())
138    }
139}
140
141impl Arguments for PgArguments {
142    type Database = Postgres;
143
144    fn reserve(&mut self, additional: usize, size: usize) {
145        self.types.reserve(additional);
146        self.buffer.reserve(size);
147    }
148
149    fn add<'t, T>(&mut self, value: T) -> Result<(), BoxDynError>
150    where
151        T: Encode<'t, Self::Database> + Type<Self::Database>,
152    {
153        self.add(value)
154    }
155
156    fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
157        write!(writer, "${}", self.buffer.count)
158    }
159
160    #[inline(always)]
161    fn len(&self) -> usize {
162        self.buffer.count
163    }
164}
165
166impl PgArgumentBuffer {
167    pub(crate) fn encode<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
168    where
169        T: Encode<'q, Postgres>,
170    {
171        value_size_int4_checked(value.size_hint())?;
173
174        let offset = self.len();
176
177        self.extend(&[0; 4]);
178
179        let len = if let IsNull::No = value.encode(self)? {
181            value_size_int4_checked(self.len() - offset - 4)?
183        } else {
184            debug_assert_eq!(self.len(), offset + 4);
187            -1_i32
188        };
189
190        self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
193
194        Ok(())
195    }
196
197    #[allow(dead_code)]
199    pub(crate) fn patch<F>(&mut self, callback: F)
200    where
201        F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync,
202    {
203        let offset = self.len();
204        let arg_index = self.count;
205
206        self.patches.push(Patch {
207            buf_offset: offset,
208            arg_index,
209            callback: Arc::new(callback),
210        });
211    }
212
213    pub(crate) fn patch_type_by_name(&mut self, type_name: &UStr) {
216        let offset = self.len();
217
218        self.extend_from_slice(&0_u32.to_be_bytes());
219        self.type_holes.push((
220            offset,
221            HoleKind::Type {
222                name: type_name.clone(),
223            },
224        ));
225    }
226
227    pub(crate) fn patch_array_type(&mut self, array: Arc<PgArrayOf>) {
228        let offset = self.len();
229
230        self.extend_from_slice(&0_u32.to_be_bytes());
231        self.type_holes.push((offset, HoleKind::Array(array)));
232    }
233
234    fn snapshot(&self) -> PgArgumentBufferSnapshot {
235        let Self {
236            buffer,
237            count,
238            patches,
239            type_holes,
240        } = self;
241
242        PgArgumentBufferSnapshot {
243            buffer_length: buffer.len(),
244            count: *count,
245            patches_length: patches.len(),
246            type_holes_length: type_holes.len(),
247        }
248    }
249
250    fn reset_to_snapshot(
251        &mut self,
252        PgArgumentBufferSnapshot {
253            buffer_length,
254            count,
255            patches_length,
256            type_holes_length,
257        }: PgArgumentBufferSnapshot,
258    ) {
259        self.buffer.truncate(buffer_length);
260        self.count = count;
261        self.patches.truncate(patches_length);
262        self.type_holes.truncate(type_holes_length);
263    }
264}
265
/// Lengths and counters of a `PgArgumentBuffer` at one point in time, used to undo
/// a failed encode via `PgArgumentBuffer::reset_to_snapshot`.
struct PgArgumentBufferSnapshot {
    /// Length of the byte buffer when the snapshot was taken.
    buffer_length: usize,
    /// Argument count when the snapshot was taken.
    count: usize,
    /// Number of registered patches when the snapshot was taken.
    patches_length: usize,
    /// Number of recorded type holes when the snapshot was taken.
    type_holes_length: usize,
}
272
273impl Deref for PgArgumentBuffer {
274    type Target = Vec<u8>;
275
276    #[inline]
277    fn deref(&self) -> &Self::Target {
278        &self.buffer
279    }
280}
281
282impl DerefMut for PgArgumentBuffer {
283    #[inline]
284    fn deref_mut(&mut self) -> &mut Self::Target {
285        &mut self.buffer
286    }
287}
288
/// Converts a byte length to the `i32` used for length prefixes in the binary encoding.
///
/// # Errors
///
/// Returns a descriptive message when `size` exceeds `i32::MAX`.
pub(crate) fn value_size_int4_checked(size: usize) -> Result<i32, String> {
    match i32::try_from(size) {
        Ok(checked) => Ok(checked),
        Err(_) => Err(format!(
            "value size would overflow in the binary protocol encoding: {size} > {}",
            i32::MAX
        )),
    }
}