1use core::{
2 future::Future,
3 mem,
4 num::NonZeroUsize,
5 ops::Deref,
6 pin::Pin,
7 task::{Context, Poll},
8};
9
10use lru::LruCache;
11use tokio::sync::SemaphorePermit;
12use xitca_io::bytes::BytesMut;
13
14use crate::{
15 BoxedFuture,
16 client::{Client, ClientBorrow, ClientBorrowMut},
17 copy::{r#Copy, CopyIn, CopyOut},
18 driver::codec::AsParams,
19 error::Error,
20 execute::Execute,
21 query::{RowAffected, RowStreamOwned},
22 session::Session,
23 statement::{Statement, StatementNamed, StatementQuery},
24 transaction::{Transaction, TransactionBuilder},
25};
26
27use super::Pool;
28
29pub struct PoolConnection<'a> {
59 pub(super) pool: &'a Pool,
60 pub(super) conn: Option<PoolClient>,
61 pub(super) _permit: SemaphorePermit<'a>,
62}
63
64impl<'a> PoolConnection<'a> {
65 #[inline]
67 pub fn transaction(&mut self) -> impl Future<Output = Result<Transaction<'_, Self>, Error>> + Send {
68 TransactionBuilder::new().begin(self)
69 }
70
71 #[inline]
73 pub fn transaction_owned(self) -> impl Future<Output = Result<Transaction<'a, Self>, Error>> + Send {
74 TransactionBuilder::new().begin_owned(self)
75 }
76
77 #[inline]
79 pub fn copy_in(&mut self, stmt: &Statement) -> impl Future<Output = Result<CopyIn<'_, Self>, Error>> + Send {
80 CopyIn::new(self, stmt)
81 }
82
83 #[inline]
85 pub async fn copy_out(&self, stmt: &Statement) -> Result<CopyOut, Error> {
86 CopyOut::new(self, stmt).await
87 }
88
89 #[inline(always)]
142 pub fn consume(self) -> Self {
143 self
144 }
145
146 pub fn cancel_token(&self) -> Session {
148 self.conn().client.cancel_token()
149 }
150
151 fn insert_cache<'c>(cache: &'c mut Cache, cli: &Client, named: &str, stmt: Statement) -> &'c CachedStatement {
152 if let Some((_, stmt)) = cache.push(Box::from(named), CachedStatement { stmt }) {
153 drop(stmt.stmt.into_guarded(&cli));
154 }
155 cache.peek_mru().unwrap().1
156 }
157
158 fn conn(&self) -> &PoolClient {
159 self.conn.as_ref().unwrap()
160 }
161
162 fn conn_mut(&mut self) -> &mut PoolClient {
163 self.conn.as_mut().unwrap()
164 }
165}
166
167impl ClientBorrow for PoolConnection<'_> {
168 #[inline]
169 fn borrow_cli_ref(&self) -> &Client {
170 &self.conn().client
171 }
172}
173
174impl ClientBorrowMut for PoolConnection<'_> {
175 #[inline]
176 fn borrow_cli_mut(&mut self) -> &mut Client {
177 &mut self.conn_mut().client
178 }
179}
180
181impl r#Copy for PoolConnection<'_> {
182 #[inline]
183 fn send_one_way<F>(&self, func: F) -> Result<(), Error>
184 where
185 F: FnOnce(&mut BytesMut) -> Result<(), Error>,
186 {
187 self.conn().client.send_one_way(func)
188 }
189}
190
191impl Drop for PoolConnection<'_> {
192 fn drop(&mut self) {
193 let conn = self.conn.take().unwrap();
194 self.pool.conn.lock().unwrap().push_back(conn);
195 }
196}
197
198pub struct CachedStatement {
203 stmt: Statement,
204}
205
206impl Clone for CachedStatement {
207 fn clone(&self) -> Self {
208 Self {
209 stmt: self.stmt.duplicate(),
210 }
211 }
212}
213
214impl Deref for CachedStatement {
215 type Target = Statement;
216
217 fn deref(&self) -> &Self::Target {
218 &self.stmt
219 }
220}
221
222pub(super) struct PoolClient {
223 client: Client,
224 cache: Cache,
225}
226
227impl PoolClient {
228 pub(super) fn closed(&self) -> bool {
229 self.client.closed()
230 }
231}
232
233type Cache = LruCache<Box<str>, CachedStatement>;
234
235impl PoolClient {
236 pub(super) fn new(client: Client, cap: NonZeroUsize) -> Self {
237 Self {
238 client,
239 cache: LruCache::new(cap),
240 }
241 }
242}
243
244impl<'c, E> Execute<&'c PoolConnection<'_>> for E
245where
246 E: Execute<&'c Client>,
247{
248 type ExecuteOutput = E::ExecuteOutput;
249 type QueryOutput = E::QueryOutput;
250
251 #[inline]
252 fn execute(self, cli: &'c PoolConnection<'_>) -> Self::ExecuteOutput {
253 E::execute(self, cli.borrow_cli_ref())
254 }
255
256 #[inline]
257 fn query(self, cli: &'c PoolConnection<'_>) -> Self::QueryOutput {
258 E::query(self, cli.borrow_cli_ref())
259 }
260}
261
262impl<'c, 's> Execute<&'c mut PoolConnection<'_>> for StatementNamed<'s>
263where
264 's: 'c,
265{
266 type ExecuteOutput = StatementCacheFuture<'c>;
267 type QueryOutput = Self::ExecuteOutput;
268
269 fn execute(self, cli: &'c mut PoolConnection) -> Self::ExecuteOutput {
270 match cli.conn_mut().cache.get(self.stmt) {
271 Some(stmt) => StatementCacheFuture::Cached(stmt.clone()),
272 None => StatementCacheFuture::Prepared(Box::pin(async move {
273 let conn = cli.conn_mut();
274 let name = self.stmt;
275 let stmt = self.execute(&conn.client).await?.leak();
276 Ok(PoolConnection::insert_cache(&mut conn.cache, &conn.client, name, stmt).clone())
277 })),
278 }
279 }
280
281 #[inline]
282 fn query(self, cli: &'c mut PoolConnection) -> Self::QueryOutput {
283 self.execute(cli)
284 }
285}
286
287pub enum StatementCacheFuture<'c> {
288 Cached(CachedStatement),
289 Prepared(BoxedFuture<'c, Result<CachedStatement, Error>>),
290 Done,
291}
292
293impl Future for StatementCacheFuture<'_> {
294 type Output = Result<CachedStatement, Error>;
295
296 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
297 let this = self.get_mut();
298 match mem::replace(this, Self::Done) {
299 Self::Cached(stmt) => Poll::Ready(Ok(stmt)),
300 Self::Prepared(mut fut) => {
301 let res = fut.as_mut().poll(cx);
302 if res.is_pending() {
303 drop(mem::replace(this, Self::Prepared(fut)));
304 }
305 res
306 }
307 Self::Done => panic!("StatementCacheFuture polled after finish"),
308 }
309 }
310}
311
312#[cfg(not(feature = "nightly"))]
313impl<'c, 's, P> Execute<&'c mut PoolConnection<'_>> for StatementQuery<'s, P>
314where
315 P: AsParams + Send + 'c,
316 's: 'c,
317{
318 type ExecuteOutput = BoxedFuture<'c, Result<RowAffected, Error>>;
319 type QueryOutput = BoxedFuture<'c, Result<RowStreamOwned, Error>>;
320
321 fn execute(self, conn: &'c mut PoolConnection<'_>) -> Self::ExecuteOutput {
322 Box::pin(async move {
323 let StatementQuery { stmt, types, params } = self;
324
325 let conn = conn.conn_mut();
326
327 let stmt = match conn.cache.get(stmt) {
328 Some(stmt) => stmt,
329 None => {
330 let prepared_stmt = Statement::named(stmt, types).execute(&conn.client).await?.leak();
331 PoolConnection::insert_cache(&mut conn.cache, &conn.client, stmt, prepared_stmt)
332 }
333 };
334
335 stmt.bind(params).query(&conn.client).await.map(RowAffected::from)
336 })
337 }
338
339 fn query(self, conn: &'c mut PoolConnection<'_>) -> Self::QueryOutput {
340 Box::pin(async move {
341 let StatementQuery { stmt, types, params } = self;
342
343 let conn = conn.conn_mut();
344
345 let stmt = match conn.cache.get(stmt) {
346 Some(stmt) => stmt,
347 None => {
348 let prepared_stmt = Statement::named(stmt, types).execute(&conn.client).await?.leak();
349 PoolConnection::insert_cache(&mut conn.cache, &conn.client, stmt, prepared_stmt)
350 }
351 };
352
353 stmt.bind(params).into_owned().query(&conn.client).await
354 })
355 }
356}
357
/// Runs an unprepared query through the connection's statement cache.
///
/// Built when the `nightly` feature is enabled; the returned futures are
/// unboxed `impl Future` associated types.
#[cfg(feature = "nightly")]
impl<'c, 's, 'p, P> Execute<&'c mut PoolConnection<'p>> for StatementQuery<'s, P>
where
    P: AsParams + Send + 'c,
    's: 'c,
    'p: 'c,
{
    type ExecuteOutput = impl Future<Output = Result<RowAffected, Error>> + Send + 'c;
    type QueryOutput = impl Future<Output = Result<RowStreamOwned, Error>> + Send + 'c;

    /// Resolves the statement from the cache (preparing and caching it on a miss),
    /// binds `params`, runs it and converts the result into [`RowAffected`].
    fn execute(self, conn: &'c mut PoolConnection<'p>) -> Self::ExecuteOutput {
        async move {
            let StatementQuery {
                stmt: name,
                types,
                params,
            } = self;

            let pooled = conn.conn_mut();

            let cached = match pooled.cache.get(name) {
                Some(hit) => hit,
                None => {
                    // Cache miss: prepare on the pooled client, then remember it.
                    let prepared = Statement::named(name, types).execute(&pooled.client).await?.leak();
                    PoolConnection::insert_cache(&mut pooled.cache, &pooled.client, name, prepared)
                }
            };

            let rows = cached.bind(params).query(&pooled.client).await?;
            Ok(RowAffected::from(rows))
        }
    }

    /// Resolves the statement from the cache (preparing and caching it on a miss),
    /// binds `params`, converts the bound query with `into_owned` and runs it.
    fn query(self, conn: &'c mut PoolConnection<'p>) -> Self::QueryOutput {
        async move {
            let StatementQuery {
                stmt: name,
                types,
                params,
            } = self;

            let pooled = conn.conn_mut();

            let cached = match pooled.cache.get(name) {
                Some(hit) => hit,
                None => {
                    // Cache miss: prepare on the pooled client, then remember it.
                    let prepared = Statement::named(name, types).execute(&pooled.client).await?.leak();
                    PoolConnection::insert_cache(&mut pooled.cache, &pooled.client, name, prepared)
                }
            };

            let bound = cached.bind(params).into_owned();
            bound.query(&pooled.client).await
        }
    }
}