1use postgres_protocol::message::frontend;
2use xitca_io::bytes::BytesMut;
3
4use crate::{
5 column::Column,
6 error::{Error, InvalidParamCount},
7 pipeline::PipelineQuery,
8 prepare::Prepare,
9 statement::{Statement, StatementCreate, StatementCreateBlocking, StatementQuery, StatementUnnamedQuery},
10 types::{BorrowToSql, IsNull, Type},
11};
12
13use super::{
14 response::{
15 IntoResponse, IntoRowStreamGuard, NoOpIntoRowStream, StatementCreateResponse, StatementCreateResponseBlocking,
16 },
17 sealed, AsParams, DriverTx, Response,
18};
19
20#[diagnostic::on_unimplemented(
23 message = "`{Self}` does not impl Encode trait",
24 label = "query statement argument must be types implement Encode trait",
25 note = "consider using the types listed below that implementing Encode trait"
26)]
27pub trait Encode: sealed::Sealed + Sized {
28 type Output: IntoResponse;
31
32 fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error>;
33
34 #[doc(hidden)]
35 #[inline(always)]
42 fn count_hint(&self) -> usize {
43 1
44 }
45}
46
47impl sealed::Sealed for &str {}
48
49impl Encode for &str {
50 type Output = Vec<Column>;
51
52 #[inline]
53 fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
54 frontend::query(self, buf)?;
55 Ok(Vec::new())
56 }
57}
58
59impl sealed::Sealed for &Statement {}
60
61impl<'s> Encode for &'s Statement {
62 type Output = &'s [Column];
63
64 #[inline]
65 fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
66 encode_bind(self.name(), self.params(), [] as [i32; 0], "", buf)?;
67 frontend::execute("", 0, buf)?;
68 if SYNC_MODE {
69 frontend::sync(buf);
70 }
71 Ok(self.columns())
72 }
73}
74
75impl<C> sealed::Sealed for StatementCreate<'_, '_, C> {}
76
77impl<'c, C> Encode for StatementCreate<'_, 'c, C>
78where
79 C: Prepare,
80{
81 type Output = StatementCreateResponse<'c, C>;
82
83 #[inline]
84 fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
85 let Self { name, stmt, types, cli } = self;
86 encode_statement_create(&name, stmt, types, buf).map(|_| StatementCreateResponse { name, cli })
87 }
88}
89
90impl<C> sealed::Sealed for StatementCreateBlocking<'_, '_, C> {}
91
92impl<'c, C> Encode for StatementCreateBlocking<'_, 'c, C>
93where
94 C: Prepare,
95{
96 type Output = StatementCreateResponseBlocking<'c, C>;
97
98 #[inline]
99 fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
100 let Self { name, stmt, types, cli } = self;
101 encode_statement_create(&name, stmt, types, buf).map(|_| StatementCreateResponseBlocking { name, cli })
102 }
103}
104
105fn encode_statement_create(name: &str, stmt: &str, types: &[Type], buf: &mut BytesMut) -> Result<(), Error> {
106 frontend::parse(name, stmt, types.iter().map(Type::oid), buf)?;
107 frontend::describe(b'S', name, buf)?;
108 frontend::sync(buf);
109 Ok(())
110}
111
112pub(crate) struct StatementCancel<'a> {
113 pub(crate) name: &'a str,
114}
115
116impl sealed::Sealed for StatementCancel<'_> {}
117
118impl Encode for StatementCancel<'_> {
119 type Output = NoOpIntoRowStream;
120
121 #[inline]
122 fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
123 let Self { name } = self;
124 frontend::close(b'S', name, buf)?;
125 frontend::sync(buf);
126 Ok(NoOpIntoRowStream)
127 }
128}
129
130impl<P> sealed::Sealed for StatementQuery<'_, P> {}
131
132impl<'s, P> Encode for StatementQuery<'s, P>
133where
134 P: AsParams,
135{
136 type Output = &'s [Column];
137
138 #[inline]
139 fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
140 let StatementQuery { stmt, params } = self;
141 encode_bind(stmt.name(), stmt.params(), params, "", buf)?;
142 frontend::execute("", 0, buf)?;
143 if SYNC_MODE {
144 frontend::sync(buf);
145 }
146 Ok(stmt.columns())
147 }
148}
149
150impl<C, P> sealed::Sealed for StatementUnnamedQuery<'_, '_, P, C> {}
151
152impl<'c, C, P> Encode for StatementUnnamedQuery<'_, 'c, P, C>
153where
154 C: Prepare,
155 P: AsParams,
156{
157 type Output = IntoRowStreamGuard<'c, C>;
158
159 #[inline]
160 fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
161 let Self {
162 stmt,
163 types,
164 cli,
165 params,
166 } = self;
167 frontend::parse("", stmt, types.iter().map(Type::oid), buf)?;
168 encode_bind("", types, params, "", buf)?;
169 frontend::describe(b'S', "", buf)?;
170 frontend::execute("", 0, buf)?;
171 if SYNC_MODE {
172 frontend::sync(buf);
173 }
174 Ok(IntoRowStreamGuard(cli))
175 }
176}
177
178pub(crate) struct PortalCreate<'a, P> {
179 pub(crate) name: &'a str,
180 pub(crate) stmt: &'a str,
181 pub(crate) types: &'a [Type],
182 pub(crate) params: P,
183}
184
185impl<P> sealed::Sealed for PortalCreate<'_, P> {}
186
187impl<P> Encode for PortalCreate<'_, P>
188where
189 P: AsParams,
190{
191 type Output = NoOpIntoRowStream;
192
193 #[inline]
194 fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
195 let PortalCreate {
196 name,
197 stmt,
198 types,
199 params,
200 } = self;
201 encode_bind(stmt, types, params, name, buf)?;
202 frontend::sync(buf);
203 Ok(NoOpIntoRowStream)
204 }
205}
206
207pub(crate) struct PortalCancel<'a> {
208 pub(crate) name: &'a str,
209}
210
211impl sealed::Sealed for PortalCancel<'_> {}
212
213impl Encode for PortalCancel<'_> {
214 type Output = NoOpIntoRowStream;
215
216 #[inline]
217 fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
218 frontend::close(b'P', self.name, buf)?;
219 frontend::sync(buf);
220 Ok(NoOpIntoRowStream)
221 }
222}
223
224pub struct PortalQuery<'a> {
225 pub(crate) name: &'a str,
226 pub(crate) columns: &'a [Column],
227 pub(crate) max_rows: i32,
228}
229
230impl sealed::Sealed for PortalQuery<'_> {}
231
232impl<'s> Encode for PortalQuery<'s> {
233 type Output = &'s [Column];
234
235 #[inline]
236 fn encode<const SYNC_MODE: bool>(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
237 let Self {
238 name,
239 max_rows,
240 columns,
241 } = self;
242 frontend::execute(name, max_rows, buf)?;
243 frontend::sync(buf);
244 Ok(columns)
245 }
246}
247
248impl sealed::Sealed for PipelineQuery<'_, '_> {}
249
250impl<'s> Encode for PipelineQuery<'s, '_> {
251 type Output = Vec<&'s [Column]>;
252
253 #[inline]
254 fn encode<const SYNC_MODE: bool>(self, buf_drv: &mut BytesMut) -> Result<Self::Output, Error> {
255 let Self { columns, buf, .. } = self;
256 buf_drv.extend_from_slice(buf);
257 Ok(columns)
258 }
259
260 #[inline(always)]
261 fn count_hint(&self) -> usize {
262 self.count
263 }
264}
265
266pub(crate) fn send_encode_query<S>(tx: &DriverTx, stmt: S) -> Result<(S::Output, Response), Error>
267where
268 S: Encode,
269{
270 let msg_count = stmt.count_hint();
271 tx.send(|buf| stmt.encode::<true>(buf), msg_count)
272}
273
274fn encode_bind<P>(stmt: &str, types: &[Type], params: P, portal: &str, buf: &mut BytesMut) -> Result<(), Error>
275where
276 P: AsParams,
277{
278 let params = params.into_iter();
279 if params.len() != types.len() {
280 return Err(Error::from(InvalidParamCount {
281 expected: types.len(),
282 params: params.len(),
283 }));
284 }
285
286 let params = params.zip(types).collect::<Vec<_>>();
287
288 frontend::bind(
289 portal,
290 stmt,
291 params.iter().map(|(p, ty)| p.borrow_to_sql().encode_format(ty) as _),
292 params.iter(),
293 |(param, ty), buf| {
294 param
295 .borrow_to_sql()
296 .to_sql_checked(ty, buf)
297 .map(|is_null| match is_null {
298 IsNull::No => postgres_protocol::IsNull::No,
299 IsNull::Yes => postgres_protocol::IsNull::Yes,
300 })
301 },
302 Some(1),
303 buf,
304 )
305 .map_err(|e| match e {
306 frontend::BindError::Conversion(e) => Error::from(e),
307 frontend::BindError::Serialization(e) => Error::from(e),
308 })
309}