1use std::sync::Arc;
2
3use xitca_io::bytes::BytesMut;
4
5use crate::{
6 client::ClientBorrow,
7 column::Column,
8 error::{Error, InvalidParamCount},
9 protocol::{self, message::frontend},
10 statement::{
11 Statement, StatementCreate, StatementCreateBlocking, StatementPreparedCancel, StatementPreparedQuery,
12 StatementPreparedQueryOwned, StatementQuery, StatementSingleRTTQueryWithCli,
13 },
14 types::{BorrowToSql, IsNull, Type},
15};
16
17use super::{
18 AsParams,
19 response::{
20 IntoResponse, IntoRowStreamGuard, NoOpIntoRowStream, StatementCreateResponse, StatementCreateResponseBlocking,
21 },
22 sealed,
23};
24
25#[diagnostic::on_unimplemented(
28 message = "`{Self}` does not impl Encode trait",
29 label = "query statement argument must be types implement Encode trait",
30 note = "consider using the types listed below that implementing Encode trait"
31)]
32pub trait Encode: sealed::Sealed + Sized {
33 type Output: IntoResponse;
36
37 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error>;
38}
39
40impl sealed::Sealed for &str {}
41
42impl Encode for &str {
43 type Output = Vec<Column>;
44
45 #[inline]
46 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
47 frontend::query(self, buf)?;
48 Ok(Vec::new())
49 }
50}
51
52impl<C> sealed::Sealed for StatementCreate<'_, '_, C> {}
53
54impl<'c, C> Encode for StatementCreate<'_, 'c, C>
55where
56 C: ClientBorrow + Sync,
57{
58 type Output = StatementCreateResponse<'c, C>;
59
60 #[inline]
61 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
62 let Self { name, stmt, types, cli } = self;
63 encode_statement_create(&name, stmt, types, buf).map(|_| StatementCreateResponse { name, cli })
64 }
65}
66
67impl<C> sealed::Sealed for StatementCreateBlocking<'_, '_, C> {}
68
69impl<'c, C> Encode for StatementCreateBlocking<'_, 'c, C>
70where
71 C: ClientBorrow,
72{
73 type Output = StatementCreateResponseBlocking<'c, C>;
74
75 #[inline]
76 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
77 let Self { name, stmt, types, cli } = self;
78 encode_statement_create(&name, stmt, types, buf).map(|_| StatementCreateResponseBlocking { name, cli })
79 }
80}
81
82fn encode_statement_create(name: &str, stmt: &str, types: &[Type], buf: &mut BytesMut) -> Result<(), Error> {
83 frontend::parse(name, stmt, types.iter().map(Type::oid), buf)?;
84 frontend::describe(b'S', name, buf)?;
85 frontend::sync(buf);
86 Ok(())
87}
88
89impl sealed::Sealed for StatementPreparedCancel<'_> {}
90
91impl Encode for StatementPreparedCancel<'_> {
92 type Output = NoOpIntoRowStream;
93
94 #[inline]
95 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
96 let Self { name } = self;
97 frontend::close(b'S', name, buf)?;
98 frontend::sync(buf);
99 Ok(NoOpIntoRowStream)
100 }
101}
102
103impl<P> sealed::Sealed for StatementPreparedQuery<'_, P> {}
104
105impl<'s, P> Encode for StatementPreparedQuery<'s, P>
106where
107 P: AsParams,
108{
109 type Output = &'s [Column];
110
111 #[inline]
112 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
113 let Self { stmt, params } = self;
114 encode_stmt_query(stmt, params, buf).map(|_| stmt.columns())
115 }
116}
117
118impl<P> sealed::Sealed for StatementPreparedQueryOwned<'_, P> {}
119
120impl<'s, P> Encode for StatementPreparedQueryOwned<'s, P>
121where
122 P: AsParams,
123{
124 type Output = Arc<[Column]>;
125
126 #[inline]
127 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
128 let Self { stmt, params } = self;
129 encode_stmt_query(stmt, params, buf).map(|_| stmt.columns_owned())
130 }
131}
132
133fn encode_stmt_query<P>(stmt: &Statement, params: P, buf: &mut BytesMut) -> Result<(), Error>
134where
135 P: AsParams,
136{
137 encode_bind(stmt.name(), stmt.params(), params, "", buf)?;
138 frontend::execute("", 0, buf)?;
139 frontend::sync(buf);
140 Ok(())
141}
142
143impl<C, P> sealed::Sealed for StatementSingleRTTQueryWithCli<'_, '_, P, C> {}
144
145impl<'c, C, P> Encode for StatementSingleRTTQueryWithCli<'_, 'c, P, C>
146where
147 C: ClientBorrow,
148 P: AsParams,
149{
150 type Output = IntoRowStreamGuard<'c, C>;
151
152 #[inline]
153 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
154 let Self { query, cli } = self;
155 let StatementQuery { stmt, params, types } = query;
156 frontend::parse("", stmt, types.iter().map(Type::oid), buf)?;
157 encode_bind("", types, params, "", buf)?;
158 frontend::describe(b'S', "", buf)?;
159 frontend::execute("", 0, buf)?;
160 frontend::sync(buf);
161 Ok(IntoRowStreamGuard(cli))
162 }
163}
164
165pub(crate) struct PortalCreate<'a, P> {
166 pub(crate) name: &'a str,
167 pub(crate) stmt: &'a str,
168 pub(crate) types: &'a [Type],
169 pub(crate) params: P,
170}
171
172impl<P> sealed::Sealed for PortalCreate<'_, P> {}
173
174impl<P> Encode for PortalCreate<'_, P>
175where
176 P: AsParams,
177{
178 type Output = NoOpIntoRowStream;
179
180 #[inline]
181 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
182 let PortalCreate {
183 name,
184 stmt,
185 types,
186 params,
187 } = self;
188 encode_bind(stmt, types, params, name, buf)?;
189 frontend::sync(buf);
190 Ok(NoOpIntoRowStream)
191 }
192}
193
194pub(crate) struct PortalCancel<'a> {
195 pub(crate) name: &'a str,
196}
197
198impl sealed::Sealed for PortalCancel<'_> {}
199
200impl Encode for PortalCancel<'_> {
201 type Output = NoOpIntoRowStream;
202
203 #[inline]
204 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
205 frontend::close(b'P', self.name, buf)?;
206 frontend::sync(buf);
207 Ok(NoOpIntoRowStream)
208 }
209}
210
211pub struct PortalQuery<'a> {
212 pub(crate) name: &'a str,
213 pub(crate) columns: &'a [Column],
214 pub(crate) max_rows: i32,
215}
216
217impl sealed::Sealed for PortalQuery<'_> {}
218
219impl<'s> Encode for PortalQuery<'s> {
220 type Output = &'s [Column];
221
222 #[inline]
223 fn encode(self, buf: &mut BytesMut) -> Result<Self::Output, Error> {
224 let Self {
225 name,
226 max_rows,
227 columns,
228 } = self;
229 frontend::execute(name, max_rows, buf)?;
230 frontend::sync(buf);
231 Ok(columns)
232 }
233}
234
235fn encode_bind<P>(stmt: &str, types: &[Type], params: P, portal: &str, buf: &mut BytesMut) -> Result<(), Error>
236where
237 P: AsParams,
238{
239 let params = params.into_iter();
240 if params.len() != types.len() {
241 return Err(Error::from(InvalidParamCount {
242 expected: types.len(),
243 params: params.len(),
244 }));
245 }
246
247 let params = params.zip(types);
248
249 frontend::bind(
250 portal,
251 stmt,
252 params.clone().map(|(p, ty)| p.borrow_to_sql().encode_format(ty) as _),
253 params,
254 |(p, ty), buf| {
255 p.borrow_to_sql().to_sql_checked(ty, buf).map(|is_null| match is_null {
256 IsNull::No => protocol::IsNull::No,
257 IsNull::Yes => protocol::IsNull::Yes,
258 })
259 },
260 Some(1),
261 buf,
262 )
263 .map_err(|e| match e {
264 frontend::BindError::Conversion(e) => Error::from(e),
265 frontend::BindError::Serialization(e) => Error::from(e),
266 })
267}