1use std::sync::Arc;
2
3use postgres_protocol::message::frontend;
4use xitca_io::bytes::BytesMut;
5
6use crate::{
7 column::Column,
8 error::{Error, InvalidParamCount},
9 prepare::Prepare,
10 statement::{
11 Statement, StatementCreate, StatementCreateBlocking, StatementPreparedQuery, StatementPreparedQueryOwned,
12 StatementQuery, StatementSingleRTTQueryWithCli,
13 },
14 types::{BorrowToSql, IsNull, Type},
15};
16
17use super::{
18 AsParams, DriverTx, Response,
19 response::{
20 IntoResponse, IntoRowStreamGuard, NoOpIntoRowStream, StatementCreateResponse, StatementCreateResponseBlocking,
21 },
22 sealed,
23};
24
/// Encodes a query-like value into PostgreSQL frontend protocol messages.
///
/// The `sealed::Sealed` supertrait restricts which types may implement this
/// trait (presumably only types inside this crate — confirm against the
/// visibility of the `sealed` module).
#[diagnostic::on_unimplemented(
    message = "`{Self}` does not impl Encode trait",
    label = "query statement argument must be types implement Encode trait",
    note = "consider using the types listed below that implementing Encode trait"
)]
pub trait Encode: sealed::Sealed + Sized {
    /// Value produced by a successful [`Encode::encode`].
    ///
    /// It must implement `IntoResponse`. The implementations in this module use
    /// it to hand state to the response side: column metadata, a client
    /// handle, or a statement name.
    type Output: IntoResponse;

    /// Consumes `self` and appends its frontend protocol messages to `buf`.
    ///
    /// # Errors
    /// Returns an error when a protocol message cannot be written or when
    /// parameter encoding fails.
    fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error>;
}
39
40impl sealed::Sealed for &str {}
41
42impl Encode for &str {
43 type Output = Vec<Column>;
44
45 #[inline]
46 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
47 frontend::query(self, buf)?;
48 Ok(Vec::new())
49 }
50}
51
52impl<C> sealed::Sealed for StatementCreate<'_, '_, C> {}
53
54impl<'c, C> Encode for StatementCreate<'_, 'c, C>
55where
56 C: Prepare,
57{
58 type Output = StatementCreateResponse<'c, C>;
59
60 #[inline]
61 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
62 let Self { name, stmt, types, cli } = self;
63 encode_statement_create(&name, stmt, types, buf).map(|_| StatementCreateResponse { name, cli })
64 }
65}
66
67impl<C> sealed::Sealed for StatementCreateBlocking<'_, '_, C> {}
68
69impl<'c, C> Encode for StatementCreateBlocking<'_, 'c, C>
70where
71 C: Prepare,
72{
73 type Output = StatementCreateResponseBlocking<'c, C>;
74
75 #[inline]
76 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
77 let Self { name, stmt, types, cli } = self;
78 encode_statement_create(&name, stmt, types, buf).map(|_| StatementCreateResponseBlocking { name, cli })
79 }
80}
81
82fn encode_statement_create(name: &str, stmt: &str, types: &[Type], buf: &mut BytesMut) -> Result<(), Error> {
83 frontend::parse(name, stmt, types.iter().map(Type::oid), buf)?;
84 frontend::describe(b'S', name, buf)?;
85 frontend::sync(buf);
86 Ok(())
87}
88
89pub(crate) struct StatementCancel<'a> {
90 pub(crate) name: &'a str,
91}
92
93impl sealed::Sealed for StatementCancel<'_> {}
94
95impl Encode for StatementCancel<'_> {
96 type Output = NoOpIntoRowStream;
97
98 #[inline]
99 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
100 let Self { name } = self;
101 frontend::close(b'S', name, buf)?;
102 frontend::sync(buf);
103 Ok(NoOpIntoRowStream)
104 }
105}
106
107impl<P> sealed::Sealed for StatementPreparedQuery<'_, P> {}
108
109impl<'s, P> Encode for StatementPreparedQuery<'s, P>
110where
111 P: AsParams,
112{
113 type Output = &'s [Column];
114
115 #[inline]
116 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
117 let Self { stmt, params } = self;
118 encode_stmt_query(stmt, params, buf).map(|_| stmt.columns())
119 }
120}
121
122impl<P> sealed::Sealed for StatementPreparedQueryOwned<'_, P> {}
123
124impl<'s, P> Encode for StatementPreparedQueryOwned<'s, P>
125where
126 P: AsParams,
127{
128 type Output = Arc<[Column]>;
129
130 #[inline]
131 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
132 let Self { stmt, params } = self;
133 encode_stmt_query(stmt, params, buf).map(|_| stmt.columns_owned())
134 }
135}
136
137fn encode_stmt_query<P>(stmt: &Statement, params: P, buf: &mut BytesMut) -> Result<(), Error>
138where
139 P: AsParams,
140{
141 encode_bind(stmt.name(), stmt.params(), params, "", buf)?;
142 frontend::execute("", 0, buf)?;
143 frontend::sync(buf);
144 Ok(())
145}
146
147impl<C, P> sealed::Sealed for StatementSingleRTTQueryWithCli<'_, '_, P, C> {}
148
149impl<'c, C, P> Encode for StatementSingleRTTQueryWithCli<'_, 'c, P, C>
150where
151 C: Prepare,
152 P: AsParams,
153{
154 type Output = IntoRowStreamGuard<'c, C>;
155
156 #[inline]
157 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
158 let Self { query, cli } = self;
159 let StatementQuery { stmt, params, types } = query;
160 frontend::parse("", stmt, types.iter().map(Type::oid), buf)?;
161 encode_bind("", types, params, "", buf)?;
162 frontend::describe(b'S', "", buf)?;
163 frontend::execute("", 0, buf)?;
164 frontend::sync(buf);
165 Ok(IntoRowStreamGuard(cli))
166 }
167}
168
169pub(crate) struct PortalCreate<'a, P> {
170 pub(crate) name: &'a str,
171 pub(crate) stmt: &'a str,
172 pub(crate) types: &'a [Type],
173 pub(crate) params: P,
174}
175
176impl<P> sealed::Sealed for PortalCreate<'_, P> {}
177
178impl<P> Encode for PortalCreate<'_, P>
179where
180 P: AsParams,
181{
182 type Output = NoOpIntoRowStream;
183
184 #[inline]
185 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
186 let PortalCreate {
187 name,
188 stmt,
189 types,
190 params,
191 } = self;
192 encode_bind(stmt, types, params, name, buf)?;
193 frontend::sync(buf);
194 Ok(NoOpIntoRowStream)
195 }
196}
197
198pub(crate) struct PortalCancel<'a> {
199 pub(crate) name: &'a str,
200}
201
202impl sealed::Sealed for PortalCancel<'_> {}
203
204impl Encode for PortalCancel<'_> {
205 type Output = NoOpIntoRowStream;
206
207 #[inline]
208 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
209 frontend::close(b'P', self.name, buf)?;
210 frontend::sync(buf);
211 Ok(NoOpIntoRowStream)
212 }
213}
214
215pub struct PortalQuery<'a> {
216 pub(crate) name: &'a str,
217 pub(crate) columns: &'a [Column],
218 pub(crate) max_rows: i32,
219}
220
221impl sealed::Sealed for PortalQuery<'_> {}
222
223impl<'s> Encode for PortalQuery<'s> {
224 type Output = &'s [Column];
225
226 #[inline]
227 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
228 let Self {
229 name,
230 max_rows,
231 columns,
232 } = self;
233 frontend::execute(name, max_rows, buf)?;
234 frontend::sync(buf);
235 Ok(columns)
236 }
237}
238
239pub(crate) fn send_encode_query<S>(tx: &DriverTx, stmt: S) -> Result<(S::Output, Response), Error>
240where
241 S: Encode,
242{
243 tx.send(|buf| stmt.encode(buf))
244}
245
246fn encode_bind<P>(stmt: &str, types: &[Type], params: P, portal: &str, buf: &mut BytesMut) -> Result<(), Error>
247where
248 P: AsParams,
249{
250 let params = params.into_iter();
251 if params.len() != types.len() {
252 return Err(Error::from(InvalidParamCount {
253 expected: types.len(),
254 params: params.len(),
255 }));
256 }
257
258 let params = params.zip(types);
259
260 frontend::bind(
261 portal,
262 stmt,
263 params.clone().map(|(p, ty)| p.borrow_to_sql().encode_format(ty) as _),
264 params,
265 |(p, ty), buf| {
266 p.borrow_to_sql().to_sql_checked(ty, buf).map(|is_null| match is_null {
267 IsNull::No => postgres_protocol::IsNull::No,
268 IsNull::Yes => postgres_protocol::IsNull::Yes,
269 })
270 },
271 Some(1),
272 buf,
273 )
274 .map_err(|e| match e {
275 frontend::BindError::Conversion(e) => Error::from(e),
276 frontend::BindError::Serialization(e) => Error::from(e),
277 })
278}