sea_orm/database/stream/transaction.rs

1#![allow(missing_docs)]
2
3use std::{ops::DerefMut, pin::Pin, task::Poll};
4use tracing::instrument;
5
6use futures_util::Stream;
7#[cfg(feature = "sqlx-dep")]
8use futures_util::TryStreamExt;
9
10use futures_util::lock::MutexGuard;
11
12#[cfg(feature = "sqlx-dep")]
13use sqlx::Executor;
14
15use super::metric::MetricStream;
16#[cfg(feature = "sqlx-dep")]
17use crate::driver::*;
18use crate::{DbErr, InnerConnection, QueryResult, Statement};
19
/// A stream of query results that holds the connection lock for as long as it is alive.
///
/// The struct owns the locked connection (`conn`), the statement being executed (`stmt`) and
/// the optional metric callback; the `stream` field borrows all three, which is why the type
/// is generated with `ouroboros::self_referencing`.
///
/// `TransactionStream` cannot be used in a `transaction` closure because it does not implement
/// `Send`. This appears to be a current Rust limitation, and working around it has been judged
/// extremely hard.
#[ouroboros::self_referencing]
pub struct TransactionStream<'a> {
    // The statement being streamed; borrowed by `stream`.
    stmt: Statement,
    // Mutex guard over the connection; the lock is held for as long as this struct owns it.
    conn: MutexGuard<'a, InnerConnection>,
    // Optional metrics hook, passed to `MetricStream` when the stream is built.
    metric_callback: Option<crate::metric::Callback>,
    // `stream` borrows `conn` mutably and `stmt` / `metric_callback` immutably.
    #[borrows(mut conn, stmt, metric_callback)]
    // Marked not covariant in `'this`, so ouroboros exposes `stream` only through
    // closure-based accessors such as `with_stream_mut`.
    #[not_covariant]
    stream: MetricStream<'this>,
}
31
32impl std::fmt::Debug for TransactionStream<'_> {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(f, "TransactionStream")
35    }
36}
37
impl TransactionStream<'_> {
    /// Creates a `TransactionStream` that runs `stmt` on the connection guarded by `conn`.
    ///
    /// Ownership of the guard, the statement and the optional metric callback moves into the
    /// self-referential struct. The `stream_builder` closure then borrows them back to create
    /// the backend-specific result stream. That stream is wrapped in a `MetricStream` together
    /// with the callback, the statement and the measured setup time.
    ///
    /// # Panics
    ///
    /// Reaches `unreachable!()` if `conn` holds an `InnerConnection` variant that has no
    /// matching arm for the enabled feature set.
    // Emits a `trace`-level tracing span. `metric_callback` is excluded from the recorded
    // fields (NOTE(review): presumably because it does not implement `Debug` — confirm).
    #[instrument(level = "trace", skip(metric_callback))]
    // Needed at least for the `Proxy` arm, which binds `c` without using it. Parameters can
    // also go unused when backend arms are compiled out.
    #[allow(unused_variables)]
    pub(crate) fn build(
        conn: MutexGuard<'_, InnerConnection>,
        stmt: Statement,
        metric_callback: Option<crate::metric::Callback>,
    ) -> TransactionStream<'_> {
        TransactionStreamBuilder {
            stmt,
            conn,
            metric_callback,
            // Receives borrows of the owned fields (`conn` mutably, `stmt` and
            // `metric_callback` shared) and returns a `MetricStream` tied to those borrows.
            // The leading underscore on `_metric_callback` avoids unused warnings when no
            // backend arm is compiled in.
            stream_builder: |conn, stmt, _metric_callback| match conn.deref_mut() {
                #[cfg(feature = "sqlx-mysql")]
                InnerConnection::MySql(c) => {
                    // Convert the statement into a sqlx query for this backend.
                    let query = crate::driver::sqlx_mysql::sqlx_query(stmt);
                    // Read the clock only when a metric callback is registered.
                    let start = _metric_callback.is_some().then(std::time::SystemTime::now);
                    // Convert rows with `Into::into` into the item type `MetricStream::new`
                    // expects (presumably `QueryResult`). Map errors through
                    // `sqlx_error_to_query_err`.
                    let stream = c
                        .fetch(query)
                        .map_ok(Into::into)
                        .map_err(sqlx_error_to_query_err);
                    // Time from `start` until the stream value was created here. If the system
                    // clock went backwards, `unwrap_or_default` yields a zero duration.
                    // NOTE(review): any time spent while polling must be accounted for inside
                    // `MetricStream`, which is not visible here — confirm.
                    let elapsed = start.map(|s| s.elapsed().unwrap_or_default());
                    MetricStream::new(_metric_callback, stmt, elapsed, stream)
                }
                #[cfg(feature = "sqlx-postgres")]
                InnerConnection::Postgres(c) => {
                    // Same steps as the MySQL arm, using the Postgres driver's query builder.
                    let query = crate::driver::sqlx_postgres::sqlx_query(stmt);
                    let start = _metric_callback.is_some().then(std::time::SystemTime::now);
                    let stream = c
                        .fetch(query)
                        .map_ok(Into::into)
                        .map_err(sqlx_error_to_query_err);
                    let elapsed = start.map(|s| s.elapsed().unwrap_or_default());
                    MetricStream::new(_metric_callback, stmt, elapsed, stream)
                }
                #[cfg(feature = "sqlx-sqlite")]
                InnerConnection::Sqlite(c) => {
                    // Same steps as the MySQL arm, using the SQLite driver's query builder.
                    let query = crate::driver::sqlx_sqlite::sqlx_query(stmt);
                    let start = _metric_callback.is_some().then(std::time::SystemTime::now);
                    let stream = c
                        .fetch(query)
                        .map_ok(Into::into)
                        .map_err(sqlx_error_to_query_err);
                    let elapsed = start.map(|s| s.elapsed().unwrap_or_default());
                    MetricStream::new(_metric_callback, stmt, elapsed, stream)
                }
                #[cfg(feature = "mock")]
                InnerConnection::Mock(c) => {
                    let start = _metric_callback.is_some().then(std::time::SystemTime::now);
                    // The mock connection builds its stream directly from the statement; no
                    // sqlx query conversion is involved.
                    let stream = c.fetch(stmt);
                    let elapsed = start.map(|s| s.elapsed().unwrap_or_default());
                    MetricStream::new(_metric_callback, stmt, elapsed, stream)
                }
                #[cfg(feature = "proxy")]
                InnerConnection::Proxy(c) => {
                    let start = _metric_callback.is_some().then(std::time::SystemTime::now);
                    // Proxy connections do not support this stream. It yields exactly one
                    // `BackendNotSupported` error and then ends.
                    let stream = futures_util::stream::once(async {
                        Err(DbErr::BackendNotSupported {
                            db: "Proxy",
                            ctx: "TransactionStream",
                        })
                    });
                    let elapsed = start.map(|s| s.elapsed().unwrap_or_default());
                    MetricStream::new(_metric_callback, stmt, elapsed, stream)
                }
                // Catch-all for variants that have no arm above under the enabled features.
                // `unreachable_patterns` is allowed because the arm is dead when every variant
                // is matched.
                #[allow(unreachable_patterns)]
                _ => unreachable!(),
            },
        }
        .build()
    }
}
110
111#[cfg(not(feature = "sync"))]
112impl Stream for TransactionStream<'_> {
113    type Item = Result<QueryResult, DbErr>;
114
115    fn poll_next(
116        self: Pin<&mut Self>,
117        cx: &mut std::task::Context<'_>,
118    ) -> Poll<Option<Self::Item>> {
119        let this = self.get_mut();
120        this.with_stream_mut(|stream| Pin::new(stream).poll_next(cx))
121    }
122}
123
// Blocking access (`sync` feature): each step is forwarded to the wrapped `MetricStream`.
#[cfg(feature = "sync")]
impl Iterator for TransactionStream<'_> {
    type Item = Result<QueryResult, DbErr>;

    /// Pulls the next result from the inner `MetricStream`.
    fn next(&mut self) -> Option<Self::Item> {
        Self::with_stream_mut(self, |inner| inner.next())
    }
}