Skip to main content

sql_orm_tiberius/
transaction.rs

1use crate::config::{MssqlSlowQueryOptions, MssqlTracingOptions};
2use crate::connection::run_with_timeout;
3use crate::error::{TiberiusErrorContext, map_tiberius_error};
4use crate::executor::{
5    ExecuteResult, execute_compiled, fetch_all_compiled, fetch_one_compiled, query_raw_compiled,
6};
7use crate::telemetry::trace_transaction_command;
8use futures_io::{AsyncRead, AsyncWrite};
9use sql_orm_core::{FromRow, OrmError};
10use sql_orm_query::CompiledQuery;
11use std::time::Duration;
12use tiberius::{Client, QueryStream};
13
// Transaction-control statements sent verbatim through `run_transaction_command`.

/// T-SQL text issued when a transaction scope is opened.
const BEGIN_TRANSACTION_SQL: &str = "BEGIN TRANSACTION";
/// T-SQL text issued by the commit path.
const COMMIT_TRANSACTION_SQL: &str = "COMMIT TRANSACTION";
/// T-SQL text issued by the rollback path.
const ROLLBACK_TRANSACTION_SQL: &str = "ROLLBACK TRANSACTION";
17
18pub struct MssqlTransaction<'a, S: AsyncRead + AsyncWrite + Unpin + Send> {
19    client: &'a mut Client<S>,
20    query_timeout: Option<Duration>,
21    tracing_options: MssqlTracingOptions,
22    slow_query_options: MssqlSlowQueryOptions,
23    server_addr: String,
24    completed: bool,
25}
26
27impl<'a, S> MssqlTransaction<'a, S>
28where
29    S: AsyncRead + AsyncWrite + Unpin + Send,
30{
31    pub(crate) async fn begin(
32        client: &'a mut Client<S>,
33        query_timeout: Option<Duration>,
34        tracing_options: MssqlTracingOptions,
35        slow_query_options: MssqlSlowQueryOptions,
36        server_addr: String,
37    ) -> Result<Self, OrmError> {
38        begin_transaction_scope(client, query_timeout, tracing_options, &server_addr).await?;
39
40        Ok(Self {
41            client,
42            query_timeout,
43            tracing_options,
44            slow_query_options,
45            server_addr,
46            completed: false,
47        })
48    }
49
50    pub fn is_completed(&self) -> bool {
51        self.completed
52    }
53
54    pub async fn commit(mut self) -> Result<(), OrmError> {
55        self.finish(COMMIT_TRANSACTION_SQL).await
56    }
57
58    pub async fn rollback(mut self) -> Result<(), OrmError> {
59        self.finish(ROLLBACK_TRANSACTION_SQL).await
60    }
61
62    pub async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
63        run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
64            execute_compiled(
65                self.client,
66                query,
67                self.tracing_options,
68                self.slow_query_options,
69                &self.server_addr,
70                self.query_timeout,
71            )
72            .await
73        })
74        .await
75    }
76
77    pub async fn query_raw<'b>(
78        &'b mut self,
79        query: CompiledQuery,
80    ) -> Result<QueryStream<'b>, OrmError> {
81        run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
82            query_raw_compiled(
83                self.client,
84                query,
85                self.tracing_options,
86                self.slow_query_options,
87                &self.server_addr,
88                self.query_timeout,
89            )
90            .await
91        })
92        .await
93    }
94
95    pub async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
96    where
97        T: FromRow + Send,
98    {
99        run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
100            fetch_one_compiled(
101                self.client,
102                query,
103                self.tracing_options,
104                self.slow_query_options,
105                crate::config::MssqlRetryOptions::disabled(),
106                &self.server_addr,
107                self.query_timeout,
108            )
109            .await
110        })
111        .await
112    }
113
114    pub async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
115    where
116        T: FromRow + Send,
117    {
118        run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
119            fetch_all_compiled(
120                self.client,
121                query,
122                self.tracing_options,
123                self.slow_query_options,
124                crate::config::MssqlRetryOptions::disabled(),
125                &self.server_addr,
126                self.query_timeout,
127            )
128            .await
129        })
130        .await
131    }
132
133    async fn finish(&mut self, sql: &'static str) -> Result<(), OrmError> {
134        if self.completed {
135            return Err(OrmError::new("transaction has already been completed"));
136        }
137
138        run_transaction_command(
139            self.client,
140            sql,
141            self.query_timeout,
142            self.tracing_options,
143            &self.server_addr,
144        )
145        .await?;
146        self.completed = true;
147
148        Ok(())
149    }
150}
151
152pub(crate) async fn begin_transaction_scope<S>(
153    client: &mut Client<S>,
154    query_timeout: Option<Duration>,
155    tracing_options: MssqlTracingOptions,
156    server_addr: &str,
157) -> Result<(), OrmError>
158where
159    S: AsyncRead + AsyncWrite + Unpin + Send,
160{
161    run_transaction_command(
162        client,
163        BEGIN_TRANSACTION_SQL,
164        query_timeout,
165        tracing_options,
166        server_addr,
167    )
168    .await
169}
170
171pub(crate) async fn commit_transaction_scope<S>(
172    client: &mut Client<S>,
173    query_timeout: Option<Duration>,
174    tracing_options: MssqlTracingOptions,
175    server_addr: &str,
176) -> Result<(), OrmError>
177where
178    S: AsyncRead + AsyncWrite + Unpin + Send,
179{
180    run_transaction_command(
181        client,
182        COMMIT_TRANSACTION_SQL,
183        query_timeout,
184        tracing_options,
185        server_addr,
186    )
187    .await
188}
189
190pub(crate) async fn rollback_transaction_scope<S>(
191    client: &mut Client<S>,
192    query_timeout: Option<Duration>,
193    tracing_options: MssqlTracingOptions,
194    server_addr: &str,
195) -> Result<(), OrmError>
196where
197    S: AsyncRead + AsyncWrite + Unpin + Send,
198{
199    run_transaction_command(
200        client,
201        ROLLBACK_TRANSACTION_SQL,
202        query_timeout,
203        tracing_options,
204        server_addr,
205    )
206    .await
207}
208
209pub(crate) async fn run_transaction_command<S>(
210    client: &mut Client<S>,
211    sql: &'static str,
212    query_timeout: Option<Duration>,
213    tracing_options: MssqlTracingOptions,
214    server_addr: &str,
215) -> Result<(), OrmError>
216where
217    S: AsyncRead + AsyncWrite + Unpin + Send,
218{
219    trace_transaction_command(tracing_options, server_addr, query_timeout, sql, async {
220        run_with_timeout(query_timeout, "SQL Server query timed out", async {
221            client
222                .simple_query(sql)
223                .await
224                .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ExecuteQuery))?
225                .into_results()
226                .await
227                .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ExecuteQuery))?;
228
229            Ok(())
230        })
231        .await
232    })
233    .await
234}
235
#[cfg(test)]
mod tests {
    use super::{
        BEGIN_TRANSACTION_SQL, COMMIT_TRANSACTION_SQL, MssqlTransaction, ROLLBACK_TRANSACTION_SQL,
        begin_transaction_scope, commit_transaction_scope, rollback_transaction_scope,
    };
    use std::time::Duration;

    /// Concrete stream type used to instantiate the generic transaction items below.
    type TcpCompatStream = tokio_util::compat::Compat<tokio::net::TcpStream>;

    #[test]
    fn transaction_command_constants_match_expected_sql() {
        assert_eq!(BEGIN_TRANSACTION_SQL, "BEGIN TRANSACTION");
        assert_eq!(COMMIT_TRANSACTION_SQL, "COMMIT TRANSACTION");
        assert_eq!(ROLLBACK_TRANSACTION_SQL, "ROLLBACK TRANSACTION");
    }

    /// Exercising `is_completed` needs a live `Client`. Without a server, this only checks
    /// that the wrapper type can be instantiated for the TCP stream type.
    #[test]
    fn transaction_wrapper_is_instantiable_for_tcp_streams() {
        let wrapper = core::mem::size_of::<Option<MssqlTransaction<'static, TcpCompatStream>>>();

        assert!(wrapper > 0);
    }

    #[test]
    fn exposes_scope_level_transaction_helpers() {
        let begin = begin_transaction_scope::<TcpCompatStream>;
        let commit = commit_transaction_scope::<TcpCompatStream>;
        let rollback = rollback_transaction_scope::<TcpCompatStream>;

        let _ = (begin, commit, rollback);
    }

    #[test]
    fn transaction_timeout_is_copyable_for_runtime_use() {
        let timeout = Some(Duration::from_secs(1));
        let copy = timeout;

        // Reading `timeout` after assigning it to `copy` only compiles because
        // `Option<Duration>` is `Copy`. The transaction relies on this when it passes
        // `query_timeout` by value to every helper.
        assert_eq!(timeout, copy);
        assert_eq!(copy.map(|limit| limit.as_secs()), Some(1));
    }
}