1use std::fmt::{self, Write};
2use std::ops::{Deref, DerefMut};
3use std::sync::Arc;
4
5use crate::encode::{Encode, IsNull};
6use crate::error::Error;
7use crate::ext::ustr::UStr;
8use crate::types::Type;
9use crate::{PgConnection, PgTypeInfo, Postgres};
10
11use crate::type_info::PgArrayOf;
12pub(crate) use sqlx_core::arguments::Arguments;
13use sqlx_core::error::BoxDynError;
14
15#[derive(Default, Debug, Clone)]
26pub struct PgArgumentBuffer {
27 buffer: Vec<u8>,
28
29 count: usize,
31
32 patches: Vec<Patch>,
38
39 type_holes: Vec<(usize, HoleKind)>, }
48
49#[derive(Debug, Clone)]
50enum HoleKind {
51 Type { name: UStr },
52 Array(Arc<PgArrayOf>),
53}
54
55#[derive(Clone)]
56struct Patch {
57 buf_offset: usize,
58 arg_index: usize,
59 #[allow(clippy::type_complexity)]
60 callback: Arc<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
61}
62
63impl fmt::Debug for Patch {
64 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
65 f.debug_struct("Patch")
66 .field("buf_offset", &self.buf_offset)
67 .field("arg_index", &self.arg_index)
68 .field("callback", &"<callback>")
69 .finish()
70 }
71}
72
73#[derive(Default, Debug, Clone)]
75pub struct PgArguments {
76 pub(crate) types: Vec<PgTypeInfo>,
78
79 pub(crate) buffer: PgArgumentBuffer,
81}
82
83impl PgArguments {
84 pub(crate) fn add<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
85 where
86 T: Encode<'q, Postgres> + Type<Postgres>,
87 {
88 let type_info = value.produces().unwrap_or_else(T::type_info);
89
90 let buffer_snapshot = self.buffer.snapshot();
91
92 if let Err(error) = self.buffer.encode(value) {
94 self.buffer.reset_to_snapshot(buffer_snapshot);
97 return Err(error);
98 };
99
100 self.types.push(type_info);
102 self.buffer.count += 1;
104
105 Ok(())
106 }
107
108 pub(crate) async fn apply_patches(
111 &mut self,
112 conn: &mut PgConnection,
113 parameters: &[PgTypeInfo],
114 ) -> Result<(), Error> {
115 let PgArgumentBuffer {
116 ref patches,
117 ref type_holes,
118 ref mut buffer,
119 ..
120 } = self.buffer;
121
122 for patch in patches {
123 let buf = &mut buffer[patch.buf_offset..];
124 let ty = ¶meters[patch.arg_index];
125
126 (patch.callback)(buf, ty);
127 }
128
129 for (offset, kind) in type_holes {
130 let oid = match kind {
131 HoleKind::Type { name } => conn.fetch_type_id_by_name(name).await?,
132 HoleKind::Array(array) => conn.fetch_array_type_id(array).await?,
133 };
134 buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
135 }
136
137 Ok(())
138 }
139}
140
141impl<'q> Arguments<'q> for PgArguments {
142 type Database = Postgres;
143
144 fn reserve(&mut self, additional: usize, size: usize) {
145 self.types.reserve(additional);
146 self.buffer.reserve(size);
147 }
148
149 fn add<T>(&mut self, value: T) -> Result<(), BoxDynError>
150 where
151 T: Encode<'q, Self::Database> + Type<Self::Database>,
152 {
153 self.add(value)
154 }
155
156 fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
157 write!(writer, "${}", self.buffer.count)
158 }
159
160 #[inline(always)]
161 fn len(&self) -> usize {
162 self.buffer.count
163 }
164}
165
166impl PgArgumentBuffer {
167 pub(crate) fn encode<'q, T>(&mut self, value: T) -> Result<(), BoxDynError>
168 where
169 T: Encode<'q, Postgres>,
170 {
171 value_size_int4_checked(value.size_hint())?;
173
174 let offset = self.len();
176
177 self.extend(&[0; 4]);
178
179 let len = if let IsNull::No = value.encode(self)? {
181 value_size_int4_checked(self.len() - offset - 4)?
183 } else {
184 debug_assert_eq!(self.len(), offset + 4);
187 -1_i32
188 };
189
190 self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
193
194 Ok(())
195 }
196
197 #[allow(dead_code)]
199 pub(crate) fn patch<F>(&mut self, callback: F)
200 where
201 F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync,
202 {
203 let offset = self.len();
204 let arg_index = self.count;
205
206 self.patches.push(Patch {
207 buf_offset: offset,
208 arg_index,
209 callback: Arc::new(callback),
210 });
211 }
212
213 pub(crate) fn patch_type_by_name(&mut self, type_name: &UStr) {
216 let offset = self.len();
217
218 self.extend_from_slice(&0_u32.to_be_bytes());
219 self.type_holes.push((
220 offset,
221 HoleKind::Type {
222 name: type_name.clone(),
223 },
224 ));
225 }
226
227 pub(crate) fn patch_array_type(&mut self, array: Arc<PgArrayOf>) {
228 let offset = self.len();
229
230 self.extend_from_slice(&0_u32.to_be_bytes());
231 self.type_holes.push((offset, HoleKind::Array(array)));
232 }
233
234 fn snapshot(&self) -> PgArgumentBufferSnapshot {
235 let Self {
236 buffer,
237 count,
238 patches,
239 type_holes,
240 } = self;
241
242 PgArgumentBufferSnapshot {
243 buffer_length: buffer.len(),
244 count: *count,
245 patches_length: patches.len(),
246 type_holes_length: type_holes.len(),
247 }
248 }
249
250 fn reset_to_snapshot(
251 &mut self,
252 PgArgumentBufferSnapshot {
253 buffer_length,
254 count,
255 patches_length,
256 type_holes_length,
257 }: PgArgumentBufferSnapshot,
258 ) {
259 self.buffer.truncate(buffer_length);
260 self.count = count;
261 self.patches.truncate(patches_length);
262 self.type_holes.truncate(type_holes_length);
263 }
264}
265
/// Point-in-time lengths of a `PgArgumentBuffer`'s collections, plus its argument count,
/// used to undo a failed encode.
struct PgArgumentBufferSnapshot {
    /// Length of the encoded byte buffer.
    buffer_length: usize,
    /// Number of arguments encoded.
    count: usize,
    /// Number of registered patches.
    patches_length: usize,
    /// Number of registered type holes.
    type_holes_length: usize,
}
272
273impl Deref for PgArgumentBuffer {
274 type Target = Vec<u8>;
275
276 #[inline]
277 fn deref(&self) -> &Self::Target {
278 &self.buffer
279 }
280}
281
282impl DerefMut for PgArgumentBuffer {
283 #[inline]
284 fn deref_mut(&mut self) -> &mut Self::Target {
285 &mut self.buffer
286 }
287}
288
/// Converts a byte length to the `i32` used for length prefixes in the binary protocol.
///
/// # Errors
///
/// Returns a descriptive message when `size` is larger than `i32::MAX`.
pub(crate) fn value_size_int4_checked(size: usize) -> Result<i32, String> {
    match i32::try_from(size) {
        Ok(len) => Ok(len),
        Err(_) => Err(format!(
            "value size would overflow in the binary protocol encoding: {size} > {}",
            i32::MAX
        )),
    }
}