xitca_postgres/driver/codec/
encode.rs

1use std::sync::Arc;
2
3use postgres_protocol::message::frontend;
4use xitca_io::bytes::BytesMut;
5
6use crate::{
7    column::Column,
8    error::{Error, InvalidParamCount},
9    prepare::Prepare,
10    statement::{
11        Statement, StatementCreate, StatementCreateBlocking, StatementPreparedQuery, StatementPreparedQueryOwned,
12        StatementQuery, StatementSingleRTTQueryWithCli,
13    },
14    types::{BorrowToSql, IsNull, Type},
15};
16
17use super::{
18    AsParams, DriverTx, Response,
19    response::{
20        IntoResponse, IntoRowStreamGuard, NoOpIntoRowStream, StatementCreateResponse, StatementCreateResponseBlocking,
21    },
22    sealed,
23};
24
/// Trait for being generic over how a query is encoded.
///
/// Currently this trait cannot be implemented by library users: it requires the
/// `sealed::Sealed` supertrait.
#[diagnostic::on_unimplemented(
    message = "`{Self}` does not impl Encode trait",
    label = "query statement argument must be types implement Encode trait",
    note = "consider using the types listed below that implementing Encode trait"
)]
pub trait Encode: sealed::Sealed + Sized {
    /// Output type that defines how a potential async row streaming type should be constructed.
    /// Certain state from the encoding type may need to be passed along for constructing the stream.
    type Output: IntoResponse;

    /// Consumes `self` and writes its encoded form into `buf`.
    ///
    /// # Errors
    ///
    /// Returns an [`Error`] when the value can not be encoded.
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error>;
}
39
40impl sealed::Sealed for &str {}
41
42impl Encode for &str {
43    type Output = Vec<Column>;
44
45    #[inline]
46    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
47        frontend::query(self, buf)?;
48        Ok(Vec::new())
49    }
50}
51
52impl<C> sealed::Sealed for StatementCreate<'_, '_, C> {}
53
54impl<'c, C> Encode for StatementCreate<'_, 'c, C>
55where
56    C: Prepare,
57{
58    type Output = StatementCreateResponse<'c, C>;
59
60    #[inline]
61    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
62        let Self { name, stmt, types, cli } = self;
63        encode_statement_create(&name, stmt, types, buf).map(|_| StatementCreateResponse { name, cli })
64    }
65}
66
67impl<C> sealed::Sealed for StatementCreateBlocking<'_, '_, C> {}
68
69impl<'c, C> Encode for StatementCreateBlocking<'_, 'c, C>
70where
71    C: Prepare,
72{
73    type Output = StatementCreateResponseBlocking<'c, C>;
74
75    #[inline]
76    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
77        let Self { name, stmt, types, cli } = self;
78        encode_statement_create(&name, stmt, types, buf).map(|_| StatementCreateResponseBlocking { name, cli })
79    }
80}
81
82fn encode_statement_create(name: &str, stmt: &str, types: &[Type], buf: &mut BytesMut) -> Result<(), Error> {
83    frontend::parse(name, stmt, types.iter().map(Type::oid), buf)?;
84    frontend::describe(b'S', name, buf)?;
85    frontend::sync(buf);
86    Ok(())
87}
88
/// Request to close a statement on the server.
pub(crate) struct StatementCancel<'a> {
    /// Name passed to `frontend::close` with the `b'S'` variant when encoded.
    pub(crate) name: &'a str,
}
92
93impl sealed::Sealed for StatementCancel<'_> {}
94
95impl Encode for StatementCancel<'_> {
96    type Output = NoOpIntoRowStream;
97
98    #[inline]
99    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
100        let Self { name } = self;
101        frontend::close(b'S', name, buf)?;
102        frontend::sync(buf);
103        Ok(NoOpIntoRowStream)
104    }
105}
106
107impl<P> sealed::Sealed for StatementPreparedQuery<'_, P> {}
108
109impl<'s, P> Encode for StatementPreparedQuery<'s, P>
110where
111    P: AsParams,
112{
113    type Output = &'s [Column];
114
115    #[inline]
116    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
117        let Self { stmt, params } = self;
118        encode_stmt_query(stmt, params, buf).map(|_| stmt.columns())
119    }
120}
121
122impl<P> sealed::Sealed for StatementPreparedQueryOwned<'_, P> {}
123
124impl<'s, P> Encode for StatementPreparedQueryOwned<'s, P>
125where
126    P: AsParams,
127{
128    type Output = Arc<[Column]>;
129
130    #[inline]
131    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
132        let Self { stmt, params } = self;
133        encode_stmt_query(stmt, params, buf).map(|_| stmt.columns_owned())
134    }
135}
136
137fn encode_stmt_query<P>(stmt: &Statement, params: P, buf: &mut BytesMut) -> Result<(), Error>
138where
139    P: AsParams,
140{
141    encode_bind(stmt.name(), stmt.params(), params, "", buf)?;
142    frontend::execute("", 0, buf)?;
143    frontend::sync(buf);
144    Ok(())
145}
146
147impl<C, P> sealed::Sealed for StatementSingleRTTQueryWithCli<'_, '_, P, C> {}
148
149impl<'c, C, P> Encode for StatementSingleRTTQueryWithCli<'_, 'c, P, C>
150where
151    C: Prepare,
152    P: AsParams,
153{
154    type Output = IntoRowStreamGuard<'c, C>;
155
156    #[inline]
157    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
158        let Self { query, cli } = self;
159        let StatementQuery { stmt, params, types } = query;
160        frontend::parse("", stmt, types.iter().map(Type::oid), buf)?;
161        encode_bind("", types, params, "", buf)?;
162        frontend::describe(b'S', "", buf)?;
163        frontend::execute("", 0, buf)?;
164        frontend::sync(buf);
165        Ok(IntoRowStreamGuard(cli))
166    }
167}
168
/// Request to bind a statement and its parameters to a named portal.
pub(crate) struct PortalCreate<'a, P> {
    /// Passed to `encode_bind` as the portal argument.
    pub(crate) name: &'a str,
    /// Passed to `encode_bind` as the statement argument.
    pub(crate) stmt: &'a str,
    /// Types of the parameters; the count must match `params`.
    pub(crate) types: &'a [Type],
    /// Parameter values to bind.
    pub(crate) params: P,
}
175
176impl<P> sealed::Sealed for PortalCreate<'_, P> {}
177
178impl<P> Encode for PortalCreate<'_, P>
179where
180    P: AsParams,
181{
182    type Output = NoOpIntoRowStream;
183
184    #[inline]
185    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
186        let PortalCreate {
187            name,
188            stmt,
189            types,
190            params,
191        } = self;
192        encode_bind(stmt, types, params, name, buf)?;
193        frontend::sync(buf);
194        Ok(NoOpIntoRowStream)
195    }
196}
197
/// Request to close a portal on the server.
pub(crate) struct PortalCancel<'a> {
    /// Name passed to `frontend::close` with the `b'P'` variant when encoded.
    pub(crate) name: &'a str,
}
201
202impl sealed::Sealed for PortalCancel<'_> {}
203
204impl Encode for PortalCancel<'_> {
205    type Output = NoOpIntoRowStream;
206
207    #[inline]
208    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
209        frontend::close(b'P', self.name, buf)?;
210        frontend::sync(buf);
211        Ok(NoOpIntoRowStream)
212    }
213}
214
/// Request to execute a named portal.
pub struct PortalQuery<'a> {
    /// Portal name passed to `frontend::execute`.
    pub(crate) name: &'a str,
    /// Columns returned unchanged as the encode output.
    pub(crate) columns: &'a [Column],
    /// Row limit passed to `frontend::execute`.
    pub(crate) max_rows: i32,
}
220
221impl sealed::Sealed for PortalQuery<'_> {}
222
223impl<'s> Encode for PortalQuery<'s> {
224    type Output = &'s [Column];
225
226    #[inline]
227    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
228        let Self {
229            name,
230            max_rows,
231            columns,
232        } = self;
233        frontend::execute(name, max_rows, buf)?;
234        frontend::sync(buf);
235        Ok(columns)
236    }
237}
238
239pub(crate) fn send_encode_query<S>(tx: &DriverTx, stmt: S) -> Result<(S::Output, Response), Error>
240where
241    S: Encode,
242{
243    tx.send(|buf| stmt.encode(buf))
244}
245
246fn encode_bind<P>(stmt: &str, types: &[Type], params: P, portal: &str, buf: &mut BytesMut) -> Result<(), Error>
247where
248    P: AsParams,
249{
250    let params = params.into_iter();
251    if params.len() != types.len() {
252        return Err(Error::from(InvalidParamCount {
253            expected: types.len(),
254            params: params.len(),
255        }));
256    }
257
258    let params = params.zip(types);
259
260    frontend::bind(
261        portal,
262        stmt,
263        params.clone().map(|(p, ty)| p.borrow_to_sql().encode_format(ty) as _),
264        params,
265        |(p, ty), buf| {
266            p.borrow_to_sql().to_sql_checked(ty, buf).map(|is_null| match is_null {
267                IsNull::No => postgres_protocol::IsNull::No,
268                IsNull::Yes => postgres_protocol::IsNull::Yes,
269            })
270        },
271        Some(1),
272        buf,
273    )
274    .map_err(|e| match e {
275        frontend::BindError::Conversion(e) => Error::from(e),
276        frontend::BindError::Serialization(e) => Error::from(e),
277    })
278}