sqlx_build_trust_postgres/
arguments.rs

1use std::fmt::{self, Write};
2use std::ops::{Deref, DerefMut};
3
4use crate::encode::{Encode, IsNull};
5use crate::error::Error;
6use crate::ext::ustr::UStr;
7use crate::types::Type;
8use crate::{PgConnection, PgTypeInfo, Postgres};
9
10pub(crate) use sqlx_core::arguments::Arguments;
11
12// TODO: buf.patch(|| ...) is a poor name, can we think of a better name? Maybe `buf.lazy(||)` ?
13// TODO: Extend the patch system to support dynamic lengths
14//       Considerations:
15//          - The prefixed-len offset needs to be back-tracked and updated
16//          - message::Bind needs to take a &PgArguments and use a `write` method instead of
17//            referencing a buffer directly
18//          - The basic idea is that we write bytes for the buffer until we get somewhere
19//            that has a patch, we then apply the patch which should write to &mut Vec<u8>,
20//            backtrack and update the prefixed-len, then write until the next patch offset
21
22#[derive(Default)]
23pub struct PgArgumentBuffer {
24    buffer: Vec<u8>,
25
26    // Number of arguments
27    count: usize,
28
29    // Whenever an `Encode` impl needs to defer some work until after we resolve parameter types
30    // it can use `patch`.
31    //
32    // This currently is only setup to be useful if there is a *fixed-size* slot that needs to be
33    // tweaked from the input type. However, that's the only use case we currently have.
34    //
35    patches: Vec<(
36        usize, // offset
37        usize, // argument index
38        Box<dyn Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync>,
39    )>,
40
41    // Whenever an `Encode` impl encounters a `PgTypeInfo` object that does not have an OID
42    // It pushes a "hole" that must be patched later.
43    //
44    // The hole is a `usize` offset into the buffer with the type name that should be resolved
45    // This is done for Records and Arrays as the OID is needed well before we are in an async
46    // function and can just ask postgres.
47    //
48    type_holes: Vec<(usize, UStr)>, // Vec<{ offset, type_name }>
49}
50
51/// Implementation of [`Arguments`] for PostgreSQL.
52#[derive(Default)]
53pub struct PgArguments {
54    // Types of each bind parameter
55    pub(crate) types: Vec<PgTypeInfo>,
56
57    // Buffer of encoded bind parameters
58    pub(crate) buffer: PgArgumentBuffer,
59}
60
61impl PgArguments {
62    pub(crate) fn add<'q, T>(&mut self, value: T)
63    where
64        T: Encode<'q, Postgres> + Type<Postgres>,
65    {
66        // remember the type information for this value
67        self.types
68            .push(value.produces().unwrap_or_else(T::type_info));
69
70        // encode the value into our buffer
71        self.buffer.encode(value);
72
73        // increment the number of arguments we are tracking
74        self.buffer.count += 1;
75    }
76
77    // Apply patches
78    // This should only go out and ask postgres if we have not seen the type name yet
79    pub(crate) async fn apply_patches(
80        &mut self,
81        conn: &mut PgConnection,
82        parameters: &[PgTypeInfo],
83    ) -> Result<(), Error> {
84        let PgArgumentBuffer {
85            ref patches,
86            ref type_holes,
87            ref mut buffer,
88            ..
89        } = self.buffer;
90
91        for (offset, ty, callback) in patches {
92            let buf = &mut buffer[*offset..];
93            let ty = &parameters[*ty];
94
95            callback(buf, ty);
96        }
97
98        for (offset, name) in type_holes {
99            let oid = conn.fetch_type_id_by_name(&*name).await?;
100            buffer[*offset..(*offset + 4)].copy_from_slice(&oid.0.to_be_bytes());
101        }
102
103        Ok(())
104    }
105}
106
107impl<'q> Arguments<'q> for PgArguments {
108    type Database = Postgres;
109
110    fn reserve(&mut self, additional: usize, size: usize) {
111        self.types.reserve(additional);
112        self.buffer.reserve(size);
113    }
114
115    fn add<T>(&mut self, value: T)
116    where
117        T: Encode<'q, Self::Database> + Type<Self::Database>,
118    {
119        self.add(value)
120    }
121
122    fn format_placeholder<W: Write>(&self, writer: &mut W) -> fmt::Result {
123        write!(writer, "${}", self.buffer.count)
124    }
125}
126
127impl PgArgumentBuffer {
128    pub(crate) fn encode<'q, T>(&mut self, value: T)
129    where
130        T: Encode<'q, Postgres>,
131    {
132        // reserve space to write the prefixed length of the value
133        let offset = self.len();
134        self.extend(&[0; 4]);
135
136        // encode the value into our buffer
137        let len = if let IsNull::No = value.encode(self) {
138            (self.len() - offset - 4) as i32
139        } else {
140            // Write a -1 to indicate NULL
141            // NOTE: It is illegal for [encode] to write any data
142            debug_assert_eq!(self.len(), offset + 4);
143            -1_i32
144        };
145
146        // write the len to the beginning of the value
147        self[offset..(offset + 4)].copy_from_slice(&len.to_be_bytes());
148    }
149
150    // Adds a callback to be invoked later when we know the parameter type
151    #[allow(dead_code)]
152    pub(crate) fn patch<F>(&mut self, callback: F)
153    where
154        F: Fn(&mut [u8], &PgTypeInfo) + 'static + Send + Sync,
155    {
156        let offset = self.len();
157        let index = self.count;
158
159        self.patches.push((offset, index, Box::new(callback)));
160    }
161
162    // Extends the inner buffer by enough space to have an OID
163    // Remembers where the OID goes and type name for the OID
164    pub(crate) fn patch_type_by_name(&mut self, type_name: &UStr) {
165        let offset = self.len();
166
167        self.extend_from_slice(&0_u32.to_be_bytes());
168        self.type_holes.push((offset, type_name.clone()));
169    }
170}
171
172impl Deref for PgArgumentBuffer {
173    type Target = Vec<u8>;
174
175    #[inline]
176    fn deref(&self) -> &Self::Target {
177        &self.buffer
178    }
179}
180
181impl DerefMut for PgArgumentBuffer {
182    #[inline]
183    fn deref_mut(&mut self) -> &mut Self::Target {
184        &mut self.buffer
185    }
186}