xitca_postgres/
pool.rs

1use core::{
2    future::Future,
3    mem,
4    ops::Deref,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use std::{
10    collections::{HashMap, VecDeque},
11    sync::Mutex,
12};
13
14use tokio::sync::{Semaphore, SemaphorePermit};
15use xitca_io::{bytes::BytesMut, io::AsyncIo};
16
17use super::{
18    BoxedFuture, Postgres,
19    client::{Client, ClientBorrowMut},
20    config::Config,
21    copy::{r#Copy, CopyIn, CopyOut},
22    driver::{
23        Driver,
24        codec::{AsParams, Response, encode::Encode},
25        generic::GenericDriver,
26    },
27    error::Error,
28    execute::Execute,
29    iter::AsyncLendingIterator,
30    prepare::Prepare,
31    query::{Query, RowAffected, RowStreamOwned},
32    session::Session,
33    statement::{Statement, StatementNamed, StatementQuery},
34    transaction::{Transaction, TransactionBuilder},
35    types::{Oid, Type},
36};
37
38/// builder type for connection pool
39pub struct PoolBuilder {
40    config: Result<Config, Error>,
41    capacity: usize,
42}
43
44impl PoolBuilder {
45    /// set capacity. pool would spawn up to amount of capacity concurrent connections to database.
46    ///
47    /// # Default
48    /// capacity default to 1
49    pub fn capacity(mut self, cap: usize) -> Self {
50        self.capacity = cap;
51        self
52    }
53
54    /// try convert builder to a connection pool instance.
55    pub fn build(self) -> Result<Pool, Error> {
56        let config = self.config?;
57
58        Ok(Pool {
59            conn: Mutex::new(VecDeque::with_capacity(self.capacity)),
60            permits: Semaphore::new(self.capacity),
61            config: Box::new(config),
62        })
63    }
64}
65
66/// connection pool for a set of connections to database.
67pub struct Pool {
68    conn: Mutex<VecDeque<PoolClient>>,
69    permits: Semaphore,
70    config: Box<Config>,
71}
72
73impl Pool {
74    /// start a builder of pool where it's behavior can be configured.
75    pub fn builder<C>(cfg: C) -> PoolBuilder
76    where
77        Config: TryFrom<C>,
78        Error: From<<Config as TryFrom<C>>::Error>,
79    {
80        PoolBuilder {
81            config: cfg.try_into().map_err(Into::into),
82            capacity: 1,
83        }
84    }
85
86    /// try to get a connection from pool.
87    /// when pool is empty it will try to spawn new connection to database and if the process failed the outcome will
88    /// return as [`Error`]
89    pub async fn get(&self) -> Result<PoolConnection<'_>, Error> {
90        let _permit = self.permits.acquire().await.expect("Semaphore must not be closed");
91        let conn = self.conn.lock().unwrap().pop_front();
92        let conn = match conn {
93            Some(conn) if !conn.client.closed() => conn,
94            _ => self.connect().await?,
95        };
96        Ok(PoolConnection {
97            pool: self,
98            conn: Some(conn),
99            _permit,
100        })
101    }
102
103    #[cold]
104    #[inline(never)]
105    fn connect(&self) -> BoxedFuture<'_, Result<PoolClient, Error>> {
106        Box::pin(async move {
107            let (client, driver) = Postgres::new(Clone::clone(&*self.config)).connect().await?;
108            match driver {
109                Driver::Tcp(drv) => {
110                    #[cfg(feature = "io-uring")]
111                    {
112                        drive_uring(drv)
113                    }
114
115                    #[cfg(not(feature = "io-uring"))]
116                    {
117                        drive(drv)
118                    }
119                }
120                Driver::Dynamic(drv) => drive(drv),
121                #[cfg(feature = "tls")]
122                Driver::Tls(drv) => drive(drv),
123                #[cfg(unix)]
124                Driver::Unix(drv) => drive(drv),
125                #[cfg(all(unix, feature = "tls"))]
126                Driver::UnixTls(drv) => drive(drv),
127                #[cfg(feature = "quic")]
128                Driver::Quic(drv) => drive(drv),
129            };
130            Ok(PoolClient::new(client))
131        })
132    }
133}
134
135fn drive(mut drv: GenericDriver<impl AsyncIo + Send + 'static>) {
136    tokio::task::spawn(async move {
137        while drv.try_next().await?.is_some() {
138            // TODO: add notify listen callback to Pool
139        }
140        Ok::<_, Error>(())
141    });
142}
143
/// io-uring flavor of [`drive`]: converts the tcp driver into a `UringDriver` and polls it to
/// completion on a task spawned with `spawn_local`.
#[cfg(feature = "io-uring")]
fn drive_uring(drv: GenericDriver<xitca_io::net::TcpStream>) {
    use core::{async_iter::AsyncIterator, future::poll_fn, pin::pin};

    let task = async move {
        let mut iter = pin!(crate::driver::io_uring::UringDriver::from_tcp(drv).into_iter());

        loop {
            match poll_fn(|cx| iter.as_mut().poll_next(cx)).await {
                // yielded values are discarded. an error ends the task.
                Some(res) => {
                    let _ = res?;
                }
                None => break Ok::<_, Error>(()),
            }
        }
    };

    tokio::task::spawn_local(task);
}
156
157/// a RAII type for connection. it manages the lifetime of connection and it's [`Statement`] cache.
158/// a set of public is exposed to interact with them.
159///
160/// # Caching
161/// PoolConnection contains cache set of [`Statement`] to speed up regular used sql queries. when calling
162/// [`Execute::execute`] on a [`StatementNamed`] with &[`PoolConnection`] the pool connection does nothing
163/// special and function the same as a regular [`Client`]. In order to utilize the cache caller must execute
164/// the named statement with &mut [`PoolConnection`]. With a mutable reference of pool connection it will do
165/// local cache look up for statement and hand out one in the type of [`Statement`] if any found. If no
166/// copy is found in the cache pool connection will prepare a new statement and insert it into the cache.
167/// ## Examples
168/// ```
169/// # use xitca_postgres::{pool::Pool, Execute, Error, Statement};
170/// # async fn cached(pool: &Pool) -> Result<(), Error> {
171/// let mut conn = pool.get().await?;
172/// // prepare a statement without caching
173/// Statement::named("SELECT 1", &[]).execute(&conn).await?;
174/// // prepare a statement with caching from conn.
175/// Statement::named("SELECT 1", &[]).execute(&mut conn).await?;
176/// # Ok(())
177/// # }
178/// ```
179///
180/// * When to use caching or not:
181/// - query statement repeatedly called intensely can benefit from cache.
182/// - query statement with low latency requirement can benefit from upfront cache.
183/// - rare query statement can benefit from no caching by reduce resource usage from the server side. For low
184///   latency of rare query consider use [`StatementNamed::bind`] as alternative.
185pub struct PoolConnection<'a> {
186    pool: &'a Pool,
187    conn: Option<PoolClient>,
188    _permit: SemaphorePermit<'a>,
189}
190
191impl PoolConnection<'_> {
192    /// function the same as [`Client::transaction`]
193    #[inline]
194    pub fn transaction(&mut self) -> impl Future<Output = Result<Transaction<&mut Self>, Error>> + Send {
195        TransactionBuilder::new().begin(self)
196    }
197
198    /// owned version of [`PoolConnection::transaction`]
199    #[inline]
200    pub fn transaction_owned(self) -> impl Future<Output = Result<Transaction<Self>, Error>> + Send {
201        TransactionBuilder::new().begin(self)
202    }
203
204    /// function the same as [`Client::copy_in`]
205    #[inline]
206    pub fn copy_in(&mut self, stmt: &Statement) -> impl Future<Output = Result<CopyIn<'_, Self>, Error>> + Send {
207        CopyIn::new(self, stmt)
208    }
209
210    /// function the same as [`Client::copy_out`]
211    #[inline]
212    pub async fn copy_out(&self, stmt: &Statement) -> Result<CopyOut, Error> {
213        CopyOut::new(self, stmt).await
214    }
215
216    /// a shortcut to move and take ownership of self.
217    /// an important behavior of [PoolConnection] is it supports pipelining. eagerly drop it after usage can
218    /// contribute to more queries being pipelined. especially before any `await` point.
219    ///
220    /// # Examples
221    /// ```rust
222    /// use xitca_postgres::{pool::Pool, Error, Execute};
223    ///
224    /// async fn example(pool: &Pool) -> Result<(), Error> {
225    ///     // get a connection from pool and start a query.
226    ///     let mut conn = pool.get().await?;
227    ///
228    ///     "SELECT *".execute(&conn).await?;
229    ///     
230    ///     // connection is kept across await point. making it unusable to other concurrent
231    ///     // callers to example function. and no pipelining will happen until it's released.
232    ///     conn = conn;
233    ///
234    ///     // start another query but this time consume ownership and when res is returned
235    ///     // connection is dropped and went back to pool.
236    ///     let res = "SELECT *".execute(&conn.consume());
237    ///
238    ///     // connection can't be used anymore in this scope but other concurrent callers
239    ///     // to example function is able to use it and if they follow the same calling
240    ///     // convention pipelining could happen and reduce syscall overhead.
241    ///
242    ///     // let res = "SELECT *".execute(&conn);
243    ///
244    ///     // without connection the response can still be collected asynchronously
245    ///     res.await?;
246    ///
247    ///     // therefore a good calling convention for independent queries could be:
248    ///     let conn = pool.get().await?;
249    ///     let res1 = "SELECT *".execute(&conn);
250    ///     let res2 = "SELECT *".execute(&conn);
251    ///     let res3 = "SELECT *".execute(&conn.consume());
252    ///
253    ///     // all three queries can be pipelined into a single write syscall. and possibly
254    ///     // even more can be pipelined after conn.consume() is called if there are concurrent
255    ///     // callers use the same connection.
256    ///     
257    ///     res1.await?;
258    ///     res2.await?;
259    ///     res3.await?;
260    ///
261    ///     // it should be noted that pipelining is an optional crate feature for some potential
262    ///     // performance gain.
263    ///     // it's totally fine to ignore and use the apis normally with zero thought put into it.
264    ///
265    ///     Ok(())
266    /// }
267    /// ```
268    #[inline(always)]
269    pub fn consume(self) -> Self {
270        self
271    }
272
273    /// function the same as [`Client::cancel_token`]
274    pub fn cancel_token(&self) -> Session {
275        self.conn().client.cancel_token()
276    }
277
278    fn insert_cache(&mut self, named: &str, stmt: Statement) -> &CachedStatement {
279        self.conn_mut()
280            .statements
281            .entry(Box::from(named))
282            .or_insert(CachedStatement { stmt })
283    }
284
285    fn conn(&self) -> &PoolClient {
286        self.conn.as_ref().unwrap()
287    }
288
289    fn conn_mut(&mut self) -> &mut PoolClient {
290        self.conn.as_mut().unwrap()
291    }
292}
293
294impl ClientBorrowMut for PoolConnection<'_> {
295    #[inline]
296    fn _borrow_mut(&mut self) -> &mut Client {
297        &mut self.conn_mut().client
298    }
299}
300
301impl Prepare for PoolConnection<'_> {
302    #[inline]
303    async fn _get_type(&self, oid: Oid) -> Result<Type, Error> {
304        self.conn().client._get_type(oid).await
305    }
306
307    #[inline]
308    fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
309        self.conn().client._get_type_blocking(oid)
310    }
311}
312
313impl Query for PoolConnection<'_> {
314    #[inline]
315    fn _send_encode_query<S>(&self, stmt: S) -> Result<(S::Output, Response), Error>
316    where
317        S: Encode,
318    {
319        self.conn().client._send_encode_query(stmt)
320    }
321}
322
323impl r#Copy for PoolConnection<'_> {
324    #[inline]
325    fn send_one_way<F>(&self, func: F) -> Result<(), Error>
326    where
327        F: FnOnce(&mut BytesMut) -> Result<(), Error>,
328    {
329        self.conn().client.send_one_way(func)
330    }
331}
332
333impl Drop for PoolConnection<'_> {
334    fn drop(&mut self) {
335        let conn = self.conn.take().unwrap();
336        self.pool.conn.lock().unwrap().push_back(conn);
337    }
338}
339
340/// Cached [`Statement`] from [`PoolConnection`]
341///
342/// Can be used for the same purpose without the ability to cancel actively
343/// It's lifetime is managed by [`PoolConnection`]
344pub struct CachedStatement {
345    stmt: Statement,
346}
347
348impl Clone for CachedStatement {
349    fn clone(&self) -> Self {
350        Self {
351            stmt: self.stmt.duplicate(),
352        }
353    }
354}
355
356impl Deref for CachedStatement {
357    type Target = Statement;
358
359    fn deref(&self) -> &Self::Target {
360        &self.stmt
361    }
362}
363
364struct PoolClient {
365    client: Client,
366    statements: HashMap<Box<str>, CachedStatement>,
367}
368
369impl PoolClient {
370    fn new(client: Client) -> Self {
371        Self {
372            client,
373            statements: HashMap::new(),
374        }
375    }
376}
377
378impl<'c, 's> Execute<&'c mut PoolConnection<'_>> for StatementNamed<'s>
379where
380    's: 'c,
381{
382    type ExecuteOutput = StatementCacheFuture<'c>;
383    type QueryOutput = Self::ExecuteOutput;
384
385    fn execute(self, cli: &'c mut PoolConnection) -> Self::ExecuteOutput {
386        match cli.conn().statements.get(self.stmt) {
387            Some(stmt) => StatementCacheFuture::Cached(stmt.clone()),
388            None => StatementCacheFuture::Prepared(Box::pin(async move {
389                let name = self.stmt;
390                let stmt = self.execute(&*cli).await?.leak();
391                Ok(cli.insert_cache(name, stmt).clone())
392            })),
393        }
394    }
395
396    #[inline]
397    fn query(self, cli: &'c mut PoolConnection) -> Self::QueryOutput {
398        self.execute(cli)
399    }
400}
401
402#[cfg(not(feature = "nightly"))]
403impl<'c, 's, P> Execute<&'c mut PoolConnection<'_>> for StatementQuery<'s, P>
404where
405    P: AsParams + Send + 'c,
406    's: 'c,
407{
408    type ExecuteOutput = BoxedFuture<'c, Result<RowAffected, Error>>;
409    type QueryOutput = BoxedFuture<'c, Result<RowStreamOwned, Error>>;
410
411    fn execute(self, conn: &'c mut PoolConnection<'_>) -> Self::ExecuteOutput {
412        Box::pin(async move {
413            let StatementQuery { stmt, types, params } = self;
414
415            let stmt = match conn.conn().statements.get(stmt) {
416                Some(stmt) => stmt,
417                None => {
418                    let prepared_stmt = Statement::named(stmt, types).execute(&conn).await?.leak();
419                    conn.insert_cache(stmt, prepared_stmt);
420                    conn.conn().statements.get(stmt).unwrap()
421                }
422            };
423
424            stmt.bind(params).query(conn).await.map(RowAffected::from)
425        })
426    }
427
428    fn query(self, conn: &'c mut PoolConnection<'_>) -> Self::QueryOutput {
429        Box::pin(async move {
430            let StatementQuery { stmt, types, params } = self;
431
432            let stmt = match conn.conn().statements.get(stmt) {
433                Some(stmt) => stmt,
434                None => {
435                    let prepared_stmt = Statement::named(stmt, types).execute(&conn).await?.leak();
436                    conn.insert_cache(stmt, prepared_stmt);
437                    conn.conn().statements.get(stmt).unwrap()
438                }
439            };
440
441            stmt.bind(params).into_owned().query(conn).await
442        })
443    }
444}
445
/// nightly flavor of the cached query execution: same logic as the stable impl but the futures are
/// returned unboxed through `impl Future` associated types.
#[cfg(feature = "nightly")]
impl<'c, 's, 'p, P> Execute<&'c mut PoolConnection<'p>> for StatementQuery<'s, P>
where
    P: AsParams + Send + 'c,
    's: 'c,
    'p: 'c,
{
    type ExecuteOutput = impl Future<Output = Result<RowAffected, Error>> + Send + 'c;
    type QueryOutput = impl Future<Output = Result<RowStreamOwned, Error>> + Send + 'c;

    fn execute(self, conn: &'c mut PoolConnection<'p>) -> Self::ExecuteOutput {
        async move {
            let StatementQuery {
                stmt: sql,
                types,
                params,
            } = self;

            let cached = match conn.conn().statements.get(sql) {
                Some(hit) => hit,
                None => {
                    // cache miss: prepare, store, then read the entry back through a shared borrow.
                    let prepared = Statement::named(sql, types).execute(&conn).await?.leak();
                    conn.insert_cache(sql, prepared);
                    conn.conn().statements.get(sql).unwrap()
                }
            };

            cached.bind(params).query(conn).await.map(RowAffected::from)
        }
    }

    fn query(self, conn: &'c mut PoolConnection<'p>) -> Self::QueryOutput {
        async move {
            let StatementQuery {
                stmt: sql,
                types,
                params,
            } = self;

            let cached = match conn.conn().statements.get(sql) {
                Some(hit) => hit,
                None => {
                    // cache miss: prepare, store, then read the entry back through a shared borrow.
                    let prepared = Statement::named(sql, types).execute(&conn).await?.leak();
                    conn.insert_cache(sql, prepared);
                    conn.conn().statements.get(sql).unwrap()
                }
            };

            cached.bind(params).into_owned().query(conn).await
        }
    }
}
490
491// TODO: unbox returned futures when type alias is allowed in associated type.
492#[cfg(not(feature = "nightly"))]
493impl<'c, 's, P> Execute<&'c Pool> for StatementQuery<'s, P>
494where
495    P: AsParams + Send + 'c,
496    's: 'c,
497{
498    type ExecuteOutput = BoxedFuture<'c, Result<u64, Error>>;
499    type QueryOutput = BoxedFuture<'c, Result<RowStreamOwned, Error>>;
500
501    #[inline]
502    fn execute(self, pool: &'c Pool) -> Self::ExecuteOutput {
503        Box::pin(async {
504            {
505                let mut conn = pool.get().await?;
506                self.execute(&mut conn).await?
507            }
508            // return connection to pool before await on execution future
509            .await
510        })
511    }
512
513    #[inline]
514    fn query(self, pool: &'c Pool) -> Self::QueryOutput {
515        Box::pin(async {
516            let mut conn = pool.get().await?;
517            self.query(&mut conn).await
518        })
519    }
520}
521
/// nightly flavor of running a query directly against the pool. same logic as the stable impl with
/// unboxed `impl Future` associated types.
#[cfg(feature = "nightly")]
impl<'c, 's, P> Execute<&'c Pool> for StatementQuery<'s, P>
where
    P: AsParams + Send + 'c,
    's: 'c,
{
    type ExecuteOutput = impl Future<Output = Result<u64, Error>> + Send + 'c;
    type QueryOutput = impl Future<Output = Result<RowStreamOwned, Error>> + Send + 'c;

    #[inline]
    fn execute(self, pool: &'c Pool) -> Self::ExecuteOutput {
        async {
            let pending = {
                let mut conn = pool.get().await?;
                self.execute(&mut conn).await?
            };

            // connection went back to pool at the end of the block above, before the execution
            // future is awaited.
            pending.await
        }
    }

    #[inline]
    fn query(self, pool: &'c Pool) -> Self::QueryOutput {
        async {
            let mut pooled = pool.get().await?;
            self.query(&mut pooled).await
        }
    }
}
551
552// TODO: unbox returned futures when type alias is allowed in associated type.
553#[cfg(not(feature = "nightly"))]
554impl<'c, 's, I, P> Execute<&'c Pool> for I
555where
556    I: IntoIterator,
557    I::IntoIter: Iterator<Item = StatementQuery<'s, P>> + Send + 'c,
558    P: AsParams + Send + 'c,
559    's: 'c,
560{
561    type ExecuteOutput = BoxedFuture<'c, Result<u64, Error>>;
562    type QueryOutput = BoxedFuture<'c, Result<Vec<RowStreamOwned>, Error>>;
563
564    #[inline]
565    fn execute(self, pool: &'c Pool) -> Self::ExecuteOutput {
566        Box::pin(execute_iter_with_pool(self.into_iter(), pool))
567    }
568
569    #[inline]
570    fn query(self, pool: &'c Pool) -> Self::QueryOutput {
571        Box::pin(query_iter_with_pool(self.into_iter(), pool))
572    }
573}
574
/// nightly flavor of running a collection of queries against the pool. same logic as the stable impl
/// with unboxed `impl Future` associated types.
#[cfg(feature = "nightly")]
impl<'c, 's, I, P> Execute<&'c Pool> for I
where
    I: IntoIterator,
    I::IntoIter: Iterator<Item = StatementQuery<'s, P>> + Send + 'c,
    P: AsParams + Send + 'c,
    's: 'c,
{
    type ExecuteOutput = impl Future<Output = Result<u64, Error>> + Send + 'c;
    type QueryOutput = impl Future<Output = Result<Vec<RowStreamOwned>, Error>> + Send + 'c;

    #[inline]
    fn execute(self, pool: &'c Pool) -> Self::ExecuteOutput {
        let queries = self.into_iter();
        execute_iter_with_pool(queries, pool)
    }

    #[inline]
    fn query(self, pool: &'c Pool) -> Self::QueryOutput {
        let queries = self.into_iter();
        query_iter_with_pool(queries, pool)
    }
}
596
597async fn execute_iter_with_pool<P>(
598    iter: impl Iterator<Item = StatementQuery<'_, P>> + Send,
599    pool: &Pool,
600) -> Result<u64, Error>
601where
602    P: AsParams + Send,
603{
604    let mut res = Vec::with_capacity(iter.size_hint().0);
605
606    {
607        let mut conn = pool.get().await?;
608
609        for stmt in iter {
610            let fut = stmt.execute(&mut conn).await?;
611            res.push(fut);
612        }
613    }
614
615    let mut num = 0;
616
617    for res in res {
618        num += res.await?;
619    }
620
621    Ok(num)
622}
623
624async fn query_iter_with_pool<P>(
625    iter: impl Iterator<Item = StatementQuery<'_, P>> + Send,
626    pool: &Pool,
627) -> Result<Vec<RowStreamOwned>, Error>
628where
629    P: AsParams + Send,
630{
631    let mut res = Vec::with_capacity(iter.size_hint().0);
632
633    let mut conn = pool.get().await?;
634
635    for stmt in iter {
636        let stream = stmt.query(&mut conn).await?;
637        res.push(stream);
638    }
639
640    Ok(res)
641}
642
643pub enum StatementCacheFuture<'c> {
644    Cached(CachedStatement),
645    Prepared(BoxedFuture<'c, Result<CachedStatement, Error>>),
646    Done,
647}
648
649impl Future for StatementCacheFuture<'_> {
650    type Output = Result<CachedStatement, Error>;
651
652    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
653        let this = self.get_mut();
654        match mem::replace(this, Self::Done) {
655            Self::Cached(stmt) => Poll::Ready(Ok(stmt)),
656            Self::Prepared(mut fut) => {
657                let res = fut.as_mut().poll(cx);
658                if res.is_pending() {
659                    drop(mem::replace(this, Self::Prepared(fut)));
660                }
661                res
662            }
663            Self::Done => panic!("StatementCacheFuture polled after finish"),
664        }
665    }
666}
667
#[cfg(not(feature = "io-uring"))]
#[cfg(test)]
mod test {
    use super::*;

    // needs a database reachable at the address below. the pool is built with the default capacity.
    #[tokio::test]
    async fn pool() {
        let builder = Pool::builder("postgres://postgres:postgres@localhost:5432");
        let pool = builder.build().unwrap();

        {
            let mut conn = pool.get().await.unwrap();

            // prepare through the cache, then run the cached statement while giving the connection up.
            let cached = Statement::named("SELECT 1", &[]).execute(&mut conn).await.unwrap();
            cached.execute(&conn.consume()).await.unwrap();

            // the connection is back in the pool, so querying through the pool can obtain it again.
            let mut stream = Statement::named("SELECT 1", &[])
                .bind_none()
                .query(&pool)
                .await
                .unwrap();
            let row = stream.try_next().await.unwrap().unwrap();

            assert_eq!(row.get::<i32>(0), 1);
        }

        // a fixed size array of queries executed against the pool.
        let streams = [
            Statement::named("SELECT 1", &[]).bind_none(),
            Statement::named("SELECT 1", &[]).bind_none(),
        ]
        .query(&pool)
        .await
        .unwrap();

        for mut stream in streams {
            let row = stream.try_next().await.unwrap().unwrap();
            assert_eq!(row.get::<i32>(0), 1);
        }

        // a vec of queries with dynamic parameters. the outcome is deliberately ignored.
        let _ = vec![
            Statement::named("SELECT 1", &[]).bind_dyn(&[&1]),
            Statement::named("SELECT 1", &[]).bind_dyn(&[&"123"]),
            Statement::named("SELECT 1", &[]).bind_dyn(&[&String::new()]),
        ]
        .query(&pool)
        .await;
    }
}
718        .await;
719    }
720}