1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
use std::future::{Future, IntoFuture};
use std::marker::PhantomData;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Instant;

use tracing::Instrument;

use crate::error::ErrorKind;
use crate::query::Query;
use crate::{Error, Sql};

/// Allows an [`Sql`] query to be `.await`ed directly by converting it into a [`SqlFuture`].
impl<'a, Cols, T> IntoFuture for Sql<'a, Cols, T>
where
    Cols: Send + Sync + 'a,
    T: Query<Cols> + Send + Sync + 'a,
{
    type Output = Result<T, Error>;
    type IntoFuture = SqlFuture<'a, T>;

    /// Consumes the query and returns the future that executes it.
    fn into_future(self) -> Self::IntoFuture {
        SqlFuture::<'a, T>::new::<Cols>(self)
    }
}

/// Future that executes an SQL query and resolves to `Result<T, Error>`.
///
/// The query logic is type-erased into a pinned, boxed, `Send` future that may borrow
/// data for `'a`; polling this type simply forwards to that inner future.
pub struct SqlFuture<'a, T> {
    /// Lifetime marker. `future` already mentions `'a`, so this is not strictly required;
    /// it is kept because the constructors initialise it.
    marker: PhantomData<&'a ()>,
    /// The boxed query future that `poll` delegates to.
    future: Pin<Box<dyn Future<Output = Result<T, Error>> + Send + 'a>>,
}

impl<'a, T> SqlFuture<'a, T> {
    pub(crate) fn new<Cols>(sql: Sql<'a, Cols, T>) -> Self
    where
        T: Query<Cols> + Send + Sync + 'a,
        Cols: Send + Sync + 'a,
    {
        let span =
            tracing::debug_span!("sql query", query = sql.query, parameters = ?sql.parameters);
        let start = Instant::now();

        SqlFuture {
            future: Box::pin(
                // Note: changes here must be applied to `with_connection` below too!
                async move {
                    let mut i = 1;
                    loop {
                        let conn = super::connect().await?;
                        match T::query(&sql, &conn).await {
                            Ok(r) => {
                                let elapsed = start.elapsed();
                                tracing::trace!(?elapsed, "sql query finished");
                                return Ok(r);
                            }
                            Err(Error {
                                kind: ErrorKind::Postgres(err),
                                ..
                            }) if err.is_closed() && i <= 5 => {
                                // retry pool size + 1 times if connection is closed (might have
                                // received a closed one from the connection pool)
                                i += 1;
                                tracing::trace!("retry due to connection closed error");
                                continue;
                            }
                            Err(err) => {
                                return Err(err);
                            }
                        }
                    }
                }
                .instrument(span),
            ),
            marker: PhantomData,
        }
    }

    pub(crate) fn with_connection<Cols>(
        sql: Sql<'a, Cols, T>,
        conn: impl super::Connection + 'a,
    ) -> Self
    where
        T: Query<Cols> + Send + Sync + 'a,
        Cols: Send + Sync + 'a,
    {
        let span =
            tracing::debug_span!("sql query", query = sql.query, parameters = ?sql.parameters);
        let start = Instant::now();

        SqlFuture {
            future: Box::pin(
                // Note: changes here must be applied to `bew` above too!
                async move {
                    let mut i = 1;
                    loop {
                        match T::query(&sql, &conn).await {
                            Ok(r) => {
                                let elapsed = start.elapsed();
                                tracing::trace!(?elapsed, "sql query finished");
                                return Ok(r);
                            }
                            Err(Error {
                                kind: ErrorKind::Postgres(err),
                                ..
                            }) if err.is_closed() && i <= 5 => {
                                // retry pool size + 1 times if connection is closed (might have
                                // received a closed one from the connection pool)
                                i += 1;
                                tracing::trace!("retry due to connection closed error");
                                continue;
                            }
                            Err(err) => {
                                return Err(err);
                            }
                        }
                    }
                }
                .instrument(span),
            ),
            marker: PhantomData,
        }
    }
}

/// Polling a [`SqlFuture`] forwards directly to its boxed inner future.
impl<'a, T> Future for SqlFuture<'a, T> {
    type Output = Result<T, Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `SqlFuture` only holds a `Pin<Box<_>>` and a `PhantomData`, both `Unpin`, so the
        // outer pin can be dropped safely and the inner boxed future re-pinned for polling.
        let this = self.get_mut();
        let inner = this.future.as_mut();
        inner.poll(cx)
    }
}