// Source file: sql_orm_tiberius/transaction.rs

1use crate::config::{MssqlSlowQueryOptions, MssqlTracingOptions};
2use crate::connection::run_with_timeout;
3use crate::error::{TiberiusErrorContext, map_tiberius_error};
4use crate::executor::{
5    ExecuteResult, QueryExecutionOptions, execute_compiled, fetch_all_compiled, fetch_one_compiled,
6    query_raw_compiled,
7};
8use crate::telemetry::trace_transaction_command;
9use futures_io::{AsyncRead, AsyncWrite};
10use sql_orm_core::{FromRow, OrmError};
11use sql_orm_query::CompiledQuery;
12use std::time::Duration;
13use tiberius::{Client, QueryStream};
14
// T-SQL control statements sent verbatim (via `simple_query`) to open and close
// an explicit transaction scope.

/// Statement that opens the explicit transaction.
const BEGIN_TRANSACTION_SQL: &str = "BEGIN TRANSACTION";
/// Statement that commits the transaction opened by [`BEGIN_TRANSACTION_SQL`].
const COMMIT_TRANSACTION_SQL: &str = "COMMIT TRANSACTION";
/// Statement that rolls back the transaction opened by [`BEGIN_TRANSACTION_SQL`].
const ROLLBACK_TRANSACTION_SQL: &str = "ROLLBACK TRANSACTION";
18
19pub struct MssqlTransaction<'a, S: AsyncRead + AsyncWrite + Unpin + Send> {
20    client: &'a mut Client<S>,
21    query_timeout: Option<Duration>,
22    tracing_options: MssqlTracingOptions,
23    slow_query_options: MssqlSlowQueryOptions,
24    server_addr: String,
25    completed: bool,
26}
27
28impl<'a, S> MssqlTransaction<'a, S>
29where
30    S: AsyncRead + AsyncWrite + Unpin + Send,
31{
32    pub(crate) async fn begin(
33        client: &'a mut Client<S>,
34        query_timeout: Option<Duration>,
35        tracing_options: MssqlTracingOptions,
36        slow_query_options: MssqlSlowQueryOptions,
37        server_addr: String,
38    ) -> Result<Self, OrmError> {
39        begin_transaction_scope(client, query_timeout, tracing_options, &server_addr).await?;
40
41        Ok(Self {
42            client,
43            query_timeout,
44            tracing_options,
45            slow_query_options,
46            server_addr,
47            completed: false,
48        })
49    }
50
51    pub fn is_completed(&self) -> bool {
52        self.completed
53    }
54
55    pub async fn commit(mut self) -> Result<(), OrmError> {
56        self.finish(COMMIT_TRANSACTION_SQL).await
57    }
58
59    pub async fn rollback(mut self) -> Result<(), OrmError> {
60        self.finish(ROLLBACK_TRANSACTION_SQL).await
61    }
62
63    pub async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
64        run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
65            execute_compiled(
66                self.client,
67                query,
68                self.tracing_options,
69                self.slow_query_options,
70                &self.server_addr,
71                self.query_timeout,
72            )
73            .await
74        })
75        .await
76    }
77
78    pub async fn query_raw<'b>(
79        &'b mut self,
80        query: CompiledQuery,
81    ) -> Result<QueryStream<'b>, OrmError> {
82        run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
83            query_raw_compiled(
84                self.client,
85                query,
86                self.tracing_options,
87                self.slow_query_options,
88                &self.server_addr,
89                self.query_timeout,
90            )
91            .await
92        })
93        .await
94    }
95
96    pub async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
97    where
98        T: FromRow + Send,
99    {
100        run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
101            fetch_one_compiled(
102                self.client,
103                query,
104                QueryExecutionOptions {
105                    tracing: self.tracing_options,
106                    slow_query: self.slow_query_options,
107                    retry: crate::config::MssqlRetryOptions::disabled(),
108                    server_addr: &self.server_addr,
109                    timeout: self.query_timeout,
110                },
111            )
112            .await
113        })
114        .await
115    }
116
117    pub async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
118    where
119        T: FromRow + Send,
120    {
121        run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
122            fetch_all_compiled(
123                self.client,
124                query,
125                QueryExecutionOptions {
126                    tracing: self.tracing_options,
127                    slow_query: self.slow_query_options,
128                    retry: crate::config::MssqlRetryOptions::disabled(),
129                    server_addr: &self.server_addr,
130                    timeout: self.query_timeout,
131                },
132            )
133            .await
134        })
135        .await
136    }
137
138    async fn finish(&mut self, sql: &'static str) -> Result<(), OrmError> {
139        if self.completed {
140            return Err(OrmError::transaction(
141                "transaction has already been completed",
142            ));
143        }
144
145        run_transaction_command(
146            self.client,
147            sql,
148            self.query_timeout,
149            self.tracing_options,
150            &self.server_addr,
151        )
152        .await?;
153        self.completed = true;
154
155        Ok(())
156    }
157}
158
159pub(crate) async fn begin_transaction_scope<S>(
160    client: &mut Client<S>,
161    query_timeout: Option<Duration>,
162    tracing_options: MssqlTracingOptions,
163    server_addr: &str,
164) -> Result<(), OrmError>
165where
166    S: AsyncRead + AsyncWrite + Unpin + Send,
167{
168    run_transaction_command(
169        client,
170        BEGIN_TRANSACTION_SQL,
171        query_timeout,
172        tracing_options,
173        server_addr,
174    )
175    .await
176}
177
178pub(crate) async fn commit_transaction_scope<S>(
179    client: &mut Client<S>,
180    query_timeout: Option<Duration>,
181    tracing_options: MssqlTracingOptions,
182    server_addr: &str,
183) -> Result<(), OrmError>
184where
185    S: AsyncRead + AsyncWrite + Unpin + Send,
186{
187    run_transaction_command(
188        client,
189        COMMIT_TRANSACTION_SQL,
190        query_timeout,
191        tracing_options,
192        server_addr,
193    )
194    .await
195}
196
197pub(crate) async fn rollback_transaction_scope<S>(
198    client: &mut Client<S>,
199    query_timeout: Option<Duration>,
200    tracing_options: MssqlTracingOptions,
201    server_addr: &str,
202) -> Result<(), OrmError>
203where
204    S: AsyncRead + AsyncWrite + Unpin + Send,
205{
206    run_transaction_command(
207        client,
208        ROLLBACK_TRANSACTION_SQL,
209        query_timeout,
210        tracing_options,
211        server_addr,
212    )
213    .await
214}
215
216pub(crate) async fn run_transaction_command<S>(
217    client: &mut Client<S>,
218    sql: &'static str,
219    query_timeout: Option<Duration>,
220    tracing_options: MssqlTracingOptions,
221    server_addr: &str,
222) -> Result<(), OrmError>
223where
224    S: AsyncRead + AsyncWrite + Unpin + Send,
225{
226    trace_transaction_command(tracing_options, server_addr, query_timeout, sql, async {
227        run_with_timeout(query_timeout, "SQL Server query timed out", async {
228            client
229                .simple_query(sql)
230                .await
231                .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ExecuteQuery))?
232                .into_results()
233                .await
234                .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ExecuteQuery))?;
235
236            Ok(())
237        })
238        .await
239    })
240    .await
241}
242
#[cfg(test)]
mod tests {
    use super::{
        BEGIN_TRANSACTION_SQL, COMMIT_TRANSACTION_SQL, MssqlTransaction, ROLLBACK_TRANSACTION_SQL,
        begin_transaction_scope, commit_transaction_scope, rollback_transaction_scope,
    };
    use std::time::Duration;

    #[test]
    fn transaction_command_constants_match_expected_sql() {
        assert_eq!(BEGIN_TRANSACTION_SQL, "BEGIN TRANSACTION");
        assert_eq!(COMMIT_TRANSACTION_SQL, "COMMIT TRANSACTION");
        assert_eq!(ROLLBACK_TRANSACTION_SQL, "ROLLBACK TRANSACTION");
    }

    /// Smoke test that the wrapper type can be instantiated with the concrete tokio stream
    /// type used at runtime.
    ///
    /// It does not exercise the `completed` flag itself; that needs a live connection.
    #[test]
    fn transaction_wrapper_tracks_completion_state() {
        let wrapper = core::mem::size_of::<
            Option<MssqlTransaction<'static, tokio_util::compat::Compat<tokio::net::TcpStream>>>,
        >();

        assert!(wrapper > 0);
    }

    /// Compile-time check that the scope-level helpers are generic over the runtime
    /// stream type.
    #[test]
    fn exposes_scope_level_transaction_helpers() {
        let begin = begin_transaction_scope::<tokio_util::compat::Compat<tokio::net::TcpStream>>;
        let commit = commit_transaction_scope::<tokio_util::compat::Compat<tokio::net::TcpStream>>;
        let rollback =
            rollback_transaction_scope::<tokio_util::compat::Compat<tokio::net::TcpStream>>;

        let _ = (begin, commit, rollback);
    }

    /// The timeout is passed by value to several helpers, so it must be `Copy`.
    ///
    /// Using `timeout` twice below only compiles if `Option<Duration>` is `Copy`. No async
    /// runtime is needed for this.
    #[test]
    fn transaction_timeout_shape_is_copyable_for_runtime_use() {
        let timeout = Some(Duration::from_secs(1));

        let first_use = timeout;
        let second_use = timeout;

        assert_eq!(first_use, second_use);
        assert_eq!(second_use, Some(Duration::from_secs(1)));
    }
}