1use core::{
2 future::Future,
3 mem,
4 pin::Pin,
5 task::{Context, Poll},
6};
7
8use std::{
9 collections::{HashMap, VecDeque},
10 sync::{Arc, Mutex},
11};
12
13use tokio::sync::{Semaphore, SemaphorePermit};
14use xitca_io::bytes::BytesMut;
15
16use super::{
17 client::{Client, ClientBorrowMut},
18 config::Config,
19 copy::{r#Copy, CopyIn, CopyOut},
20 driver::codec::{encode::Encode, Response},
21 error::Error,
22 execute::{Execute, ExecuteMut},
23 iter::AsyncLendingIterator,
24 prepare::Prepare,
25 query::Query,
26 session::Session,
27 statement::{Statement, StatementNamed},
28 transaction::Transaction,
29 types::{Oid, Type},
30 BoxedFuture, Postgres,
31};
32
33pub struct PoolBuilder {
35 config: Result<Config, Error>,
36 capacity: usize,
37}
38
39impl PoolBuilder {
40 pub fn capacity(mut self, cap: usize) -> Self {
45 self.capacity = cap;
46 self
47 }
48
49 pub fn build(self) -> Result<Pool, Error> {
51 let config = self.config?;
52
53 Ok(Pool {
54 conn: Mutex::new(VecDeque::with_capacity(self.capacity)),
55 permits: Semaphore::new(self.capacity),
56 config,
57 })
58 }
59}
60
61pub struct Pool {
63 conn: Mutex<VecDeque<PoolClient>>,
64 permits: Semaphore,
65 config: Config,
66}
67
68impl Pool {
69 pub fn builder<C>(cfg: C) -> PoolBuilder
71 where
72 Config: TryFrom<C>,
73 Error: From<<Config as TryFrom<C>>::Error>,
74 {
75 PoolBuilder {
76 config: cfg.try_into().map_err(Into::into),
77 capacity: 1,
78 }
79 }
80
81 pub async fn get(&self) -> Result<PoolConnection<'_>, Error> {
85 let _permit = self.permits.acquire().await.expect("Semaphore must not be closed");
86 let conn = self.conn.lock().unwrap().pop_front();
87 let conn = match conn {
88 Some(conn) => conn,
89 None => self.connect().await?,
90 };
91 Ok(PoolConnection {
92 pool: self,
93 conn: Some(conn),
94 _permit,
95 })
96 }
97
98 #[inline(never)]
99 fn connect(&self) -> BoxedFuture<'_, Result<PoolClient, Error>> {
100 Box::pin(async move {
101 let (client, mut driver) = Postgres::new(self.config.clone()).connect().await?;
102 tokio::task::spawn(async move {
103 while let Ok(Some(_)) = driver.try_next().await {
104 }
106 });
107 Ok(PoolClient::new(client))
108 })
109 }
110}
111
112pub struct PoolConnection<'a> {
129 pool: &'a Pool,
130 conn: Option<PoolClient>,
131 _permit: SemaphorePermit<'a>,
132}
133
134impl PoolConnection<'_> {
135 #[inline]
137 pub fn transaction(&mut self) -> impl Future<Output = Result<Transaction<Self>, Error>> + Send {
138 Transaction::<Self>::builder().begin(self)
139 }
140
141 #[inline]
143 pub fn copy_in(&mut self, stmt: &Statement) -> impl Future<Output = Result<CopyIn<Self>, Error>> + Send {
144 CopyIn::new(self, stmt)
145 }
146
147 #[inline]
149 pub async fn copy_out(&self, stmt: &Statement) -> Result<CopyOut, Error> {
150 CopyOut::new(self, stmt).await
151 }
152
153 #[inline(always)]
206 pub fn consume(self) -> Self {
207 self
208 }
209
210 pub fn cancel_token(&self) -> Session {
212 self.conn().client.cancel_token()
213 }
214
215 fn insert_cache(&mut self, named: &str, stmt: Statement) -> Arc<Statement> {
216 let stmt = Arc::new(stmt);
217 self.conn_mut().statements.insert(Box::from(named), stmt.clone());
218 stmt
219 }
220
221 fn conn(&self) -> &PoolClient {
222 self.conn.as_ref().unwrap()
223 }
224
225 fn conn_mut(&mut self) -> &mut PoolClient {
226 self.conn.as_mut().unwrap()
227 }
228}
229
230impl ClientBorrowMut for PoolConnection<'_> {
231 #[inline]
232 fn _borrow_mut(&mut self) -> &mut Client {
233 &mut self.conn_mut().client
234 }
235}
236
237impl Prepare for PoolConnection<'_> {
238 #[inline]
239 fn _get_type(&self, oid: Oid) -> BoxedFuture<'_, Result<Type, Error>> {
240 self.conn().client._get_type(oid)
241 }
242
243 #[inline]
244 fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
245 self.conn().client._get_type_blocking(oid)
246 }
247}
248
249impl Query for PoolConnection<'_> {
250 #[inline]
251 fn _send_encode_query<S>(&self, stmt: S) -> Result<(S::Output, Response), Error>
252 where
253 S: Encode,
254 {
255 self.conn().client._send_encode_query(stmt)
256 }
257}
258
259impl r#Copy for PoolConnection<'_> {
260 #[inline]
261 fn send_one_way<F>(&self, func: F) -> Result<(), Error>
262 where
263 F: FnOnce(&mut BytesMut) -> Result<(), Error>,
264 {
265 self.conn().client.send_one_way(func)
266 }
267}
268
269impl Drop for PoolConnection<'_> {
270 fn drop(&mut self) {
271 let conn = self.conn.take().unwrap();
272 if conn.client.closed() {
273 return;
274 }
275 self.pool.conn.lock().unwrap().push_back(conn);
276 }
277}
278
279struct PoolClient {
280 client: Client,
281 statements: HashMap<Box<str>, Arc<Statement>>,
282}
283
284impl PoolClient {
285 fn new(client: Client) -> Self {
286 Self {
287 client,
288 statements: HashMap::new(),
289 }
290 }
291}
292
293impl<'c, 's> ExecuteMut<'c, PoolConnection<'_>> for StatementNamed<'s>
294where
295 's: 'c,
296{
297 type ExecuteMutOutput = StatementCacheFuture<'c>;
298 type QueryMutOutput = Self::ExecuteMutOutput;
299
300 fn execute_mut(self, cli: &'c mut PoolConnection) -> Self::ExecuteMutOutput {
301 match cli.conn().statements.get(self.stmt) {
302 Some(stmt) => StatementCacheFuture::Cached(stmt.clone()),
303 None => StatementCacheFuture::Prepared(Box::pin(async move {
304 let stmt = self.execute(cli).await?.leak();
305 Ok(cli.insert_cache(self.stmt, stmt))
306 })),
307 }
308 }
309
310 #[inline]
311 fn query_mut(self, cli: &'c mut PoolConnection) -> Self::QueryMutOutput {
312 self.execute_mut(cli)
313 }
314}
315
316pub enum StatementCacheFuture<'c> {
317 Cached(Arc<Statement>),
318 Prepared(BoxedFuture<'c, Result<Arc<Statement>, Error>>),
319 Done,
320}
321
322impl Future for StatementCacheFuture<'_> {
323 type Output = Result<Arc<Statement>, Error>;
324
325 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
326 let this = self.get_mut();
327 match mem::replace(this, Self::Done) {
328 Self::Cached(stmt) => Poll::Ready(Ok(stmt)),
329 Self::Prepared(mut fut) => {
330 let res = fut.as_mut().poll(cx);
331 if res.is_pending() {
332 drop(mem::replace(this, Self::Prepared(fut)));
333 }
334 res
335 }
336 Self::Done => panic!("StatementCacheFuture polled after finish"),
337 }
338 }
339}
340
#[cfg(test)]
mod test {
    use super::*;

    // Integration test: requires a PostgreSQL server at localhost:5432 that accepts
    // user `postgres` with password `postgres`.
    #[tokio::test]
    async fn pool() {
        let builder = Pool::builder("postgres://postgres:postgres@localhost:5432");
        let pool = builder.build().unwrap();

        let mut conn = pool.get().await.unwrap();

        // Prepare through the per-connection statement cache, then run the statement.
        let stmt = Statement::named("SELECT 1", &[]).execute_mut(&mut conn).await.unwrap();
        stmt.execute(&conn.consume()).await.unwrap();
    }
}
356}