// xitca_postgres/driver/codec/encode.rs

1use postgres_protocol::message::frontend;
2use xitca_io::bytes::BytesMut;
3
4use crate::{
5    column::Column,
6    error::{Error, InvalidParamCount},
7    pipeline::PipelineQuery,
8    prepare::Prepare,
9    statement::{Statement, StatementCreate, StatementCreateBlocking, StatementQuery, StatementUnnamedQuery},
10    types::{BorrowToSql, IsNull, Type},
11};
12
13use super::{
14    response::{
15        IntoResponse, IntoRowStreamGuard, NoOpIntoRowStream, StatementCreateResponse, StatementCreateResponseBlocking,
16    },
17    sealed, AsParams, DriverTx, Response,
18};
19
20/// trait for generic over how to encode a query.
21/// currently this trait can not be implement by library user.
22#[diagnostic::on_unimplemented(
23    message = "`{Self}` does not impl Encode trait",
24    label = "query statement argument must be types implement Encode trait",
25    note = "consider using the types listed below that implementing Encode trait"
26)]
27pub trait Encode: sealed::Sealed + Sized {
28    /// output type defines how a potential async row streaming type should be constructed.
29    /// certain state from the encode type may need to be passed for constructing the stream
30    type Output: IntoResponse;
31
32    fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error>;
33
34    #[doc(hidden)]
35    /// Hinting how many response messages will be contained by this encode type.
36    /// It **MUST** be correct count if you override this method. It determine how [`Driver`] observe boundaries
37    /// between database response messages. A wrong count will kill the driver and cause [`Client`] shutdown.
38    ///
39    /// [`Driver`]: crate::driver::Driver
40    /// [`Client`]: crate::client::Client
41    #[inline(always)]
42    fn count_hint(&self) -> usize {
43        1
44    }
45}
46
47impl sealed::Sealed for &str {}
48
49impl Encode for &str {
50    type Output = Vec<Column>;
51
52    #[inline]
53    fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
54        frontend::query(self, buf)?;
55        Ok(Vec::new())
56    }
57}
58
59impl sealed::Sealed for &Statement {}
60
61impl<'s> Encode for &'s Statement {
62    type Output = &'s [Column];
63
64    #[inline]
65    fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
66        encode_bind(self.name(), self.params(), [] as [i32; 0], "", buf)?;
67        frontend::execute("", 0, buf)?;
68        if SYNC_MODE {
69            frontend::sync(buf);
70        }
71        Ok(self.columns())
72    }
73}
74
75impl<C> sealed::Sealed for StatementCreate<'_, '_, C> {}
76
77impl<'c, C> Encode for StatementCreate<'_, 'c, C>
78where
79    C: Prepare,
80{
81    type Output = StatementCreateResponse<'c, C>;
82
83    #[inline]
84    fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
85        let Self { name, stmt, types, cli } = self;
86        encode_statement_create(&name, stmt, types, buf).map(|_| StatementCreateResponse { name, cli })
87    }
88}
89
90impl<C> sealed::Sealed for StatementCreateBlocking<'_, '_, C> {}
91
92impl<'c, C> Encode for StatementCreateBlocking<'_, 'c, C>
93where
94    C: Prepare,
95{
96    type Output = StatementCreateResponseBlocking<'c, C>;
97
98    #[inline]
99    fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
100        let Self { name, stmt, types, cli } = self;
101        encode_statement_create(&name, stmt, types, buf).map(|_| StatementCreateResponseBlocking { name, cli })
102    }
103}
104
105fn encode_statement_create(name: &str, stmt: &str, types: &[Type], buf: &mut BytesMut) -> Result<(), Error> {
106    frontend::parse(name, stmt, types.iter().map(Type::oid), buf)?;
107    frontend::describe(b'S', name, buf)?;
108    frontend::sync(buf);
109    Ok(())
110}
111
112pub(crate) struct StatementCancel<'a> {
113    pub(crate) name: &'a str,
114}
115
116impl sealed::Sealed for StatementCancel<'_> {}
117
118impl Encode for StatementCancel<'_> {
119    type Output = NoOpIntoRowStream;
120
121    #[inline]
122    fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
123        let Self { name } = self;
124        frontend::close(b'S', name, buf)?;
125        frontend::sync(buf);
126        Ok(NoOpIntoRowStream)
127    }
128}
129
130impl<P> sealed::Sealed for StatementQuery<'_, P> {}
131
132impl<'s, P> Encode for StatementQuery<'s, P>
133where
134    P: AsParams,
135{
136    type Output = &'s [Column];
137
138    #[inline]
139    fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
140        let StatementQuery { stmt, params } = self;
141        encode_bind(stmt.name(), stmt.params(), params, "", buf)?;
142        frontend::execute("", 0, buf)?;
143        if SYNC_MODE {
144            frontend::sync(buf);
145        }
146        Ok(stmt.columns())
147    }
148}
149
150impl<C, P> sealed::Sealed for StatementUnnamedQuery<'_, '_, P, C> {}
151
152impl<'c, C, P> Encode for StatementUnnamedQuery<'_, 'c, P, C>
153where
154    C: Prepare,
155    P: AsParams,
156{
157    type Output = IntoRowStreamGuard<'c, C>;
158
159    #[inline]
160    fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
161        let Self {
162            stmt,
163            types,
164            cli,
165            params,
166        } = self;
167        frontend::parse("", stmt, types.iter().map(Type::oid), buf)?;
168        encode_bind("", types, params, "", buf)?;
169        frontend::describe(b'S', "", buf)?;
170        frontend::execute("", 0, buf)?;
171        if SYNC_MODE {
172            frontend::sync(buf);
173        }
174        Ok(IntoRowStreamGuard(cli))
175    }
176}
177
178pub(crate) struct PortalCreate<'a, P> {
179    pub(crate) name: &'a str,
180    pub(crate) stmt: &'a str,
181    pub(crate) types: &'a [Type],
182    pub(crate) params: P,
183}
184
185impl<P> sealed::Sealed for PortalCreate<'_, P> {}
186
187impl<P> Encode for PortalCreate<'_, P>
188where
189    P: AsParams,
190{
191    type Output = NoOpIntoRowStream;
192
193    #[inline]
194    fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
195        let PortalCreate {
196            name,
197            stmt,
198            types,
199            params,
200        } = self;
201        encode_bind(stmt, types, params, name, buf)?;
202        frontend::sync(buf);
203        Ok(NoOpIntoRowStream)
204    }
205}
206
207pub(crate) struct PortalCancel<'a> {
208    pub(crate) name: &'a str,
209}
210
211impl sealed::Sealed for PortalCancel<'_> {}
212
213impl Encode for PortalCancel<'_> {
214    type Output = NoOpIntoRowStream;
215
216    #[inline]
217    fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
218        frontend::close(b'P', self.name, buf)?;
219        frontend::sync(buf);
220        Ok(NoOpIntoRowStream)
221    }
222}
223
224pub struct PortalQuery<'a> {
225    pub(crate) name: &'a str,
226    pub(crate) columns: &'a [Column],
227    pub(crate) max_rows: i32,
228}
229
230impl sealed::Sealed for PortalQuery<'_> {}
231
232impl<'s> Encode for PortalQuery<'s> {
233    type Output = &'s [Column];
234
235    #[inline]
236    fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
237        let Self {
238            name,
239            max_rows,
240            columns,
241        } = self;
242        frontend::execute(name, max_rows, buf)?;
243        frontend::sync(buf);
244        Ok(columns)
245    }
246}
247
248impl sealed::Sealed for PipelineQuery<'_, '_> {}
249
250impl<'s> Encode for PipelineQuery<'s, '_> {
251    type Output = Vec<&'s [Column]>;
252
253    #[inline]
254    fn encode<const SYNC_MODE: bool>(self, buf_drv: &mut BytesMut) -> Result<Self::Output, Error> {
255        let Self { columns, buf, .. } = self;
256        buf_drv.extend_from_slice(buf);
257        Ok(columns)
258    }
259
260    #[inline(always)]
261    fn count_hint(&self) -> usize {
262        self.count
263    }
264}
265
266pub(crate) fn send_encode_query<S>(tx: &DriverTx, stmt: S) -> Result<(S::Output, Response), Error>
267where
268    S: Encode,
269{
270    let msg_count = stmt.count_hint();
271    tx.send(|buf| stmt.encode::<true>(buf), msg_count)
272}
273
274fn encode_bind<P>(stmt: &str, types: &[Type], params: P, portal: &str, buf: &mut BytesMut) -> Result<(), Error>
275where
276    P: AsParams,
277{
278    let params = params.into_iter();
279    if params.len() != types.len() {
280        return Err(Error::from(InvalidParamCount {
281            expected: types.len(),
282            params: params.len(),
283        }));
284    }
285
286    let params = params.zip(types).collect::<Vec<_>>();
287
288    frontend::bind(
289        portal,
290        stmt,
291        params.iter().map(|(p, ty)| p.borrow_to_sql().encode_format(ty) as _),
292        params.iter(),
293        |(param, ty), buf| {
294            param
295                .borrow_to_sql()
296                .to_sql_checked(ty, buf)
297                .map(|is_null| match is_null {
298                    IsNull::No => postgres_protocol::IsNull::No,
299                    IsNull::Yes => postgres_protocol::IsNull::Yes,
300                })
301        },
302        Some(1),
303        buf,
304    )
305    .map_err(|e| match e {
306        frontend::BindError::Conversion(e) => Error::from(e),
307        frontend::BindError::Serialization(e) => Error::from(e),
308    })
309}