1use crate::config::{MssqlSlowQueryOptions, MssqlTracingOptions};
2use crate::connection::run_with_timeout;
3use crate::error::{TiberiusErrorContext, map_tiberius_error};
4use crate::executor::{
5 ExecuteResult, execute_compiled, fetch_all_compiled, fetch_one_compiled, query_raw_compiled,
6};
7use crate::telemetry::trace_transaction_command;
8use futures_io::{AsyncRead, AsyncWrite};
9use sql_orm_core::{FromRow, OrmError};
10use sql_orm_query::CompiledQuery;
11use std::time::Duration;
12use tiberius::{Client, QueryStream};
13
/// Statement sent as a simple batch to open an explicit transaction.
const BEGIN_TRANSACTION_SQL: &str = "BEGIN TRANSACTION";
/// Statement sent as a simple batch to commit the open transaction.
const COMMIT_TRANSACTION_SQL: &str = "COMMIT TRANSACTION";
/// Statement sent as a simple batch to roll back the open transaction.
const ROLLBACK_TRANSACTION_SQL: &str = "ROLLBACK TRANSACTION";
17
18pub struct MssqlTransaction<'a, S: AsyncRead + AsyncWrite + Unpin + Send> {
19 client: &'a mut Client<S>,
20 query_timeout: Option<Duration>,
21 tracing_options: MssqlTracingOptions,
22 slow_query_options: MssqlSlowQueryOptions,
23 server_addr: String,
24 completed: bool,
25}
26
27impl<'a, S> MssqlTransaction<'a, S>
28where
29 S: AsyncRead + AsyncWrite + Unpin + Send,
30{
31 pub(crate) async fn begin(
32 client: &'a mut Client<S>,
33 query_timeout: Option<Duration>,
34 tracing_options: MssqlTracingOptions,
35 slow_query_options: MssqlSlowQueryOptions,
36 server_addr: String,
37 ) -> Result<Self, OrmError> {
38 begin_transaction_scope(client, query_timeout, tracing_options, &server_addr).await?;
39
40 Ok(Self {
41 client,
42 query_timeout,
43 tracing_options,
44 slow_query_options,
45 server_addr,
46 completed: false,
47 })
48 }
49
50 pub fn is_completed(&self) -> bool {
51 self.completed
52 }
53
54 pub async fn commit(mut self) -> Result<(), OrmError> {
55 self.finish(COMMIT_TRANSACTION_SQL).await
56 }
57
58 pub async fn rollback(mut self) -> Result<(), OrmError> {
59 self.finish(ROLLBACK_TRANSACTION_SQL).await
60 }
61
62 pub async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
63 run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
64 execute_compiled(
65 self.client,
66 query,
67 self.tracing_options,
68 self.slow_query_options,
69 &self.server_addr,
70 self.query_timeout,
71 )
72 .await
73 })
74 .await
75 }
76
77 pub async fn query_raw<'b>(
78 &'b mut self,
79 query: CompiledQuery,
80 ) -> Result<QueryStream<'b>, OrmError> {
81 run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
82 query_raw_compiled(
83 self.client,
84 query,
85 self.tracing_options,
86 self.slow_query_options,
87 &self.server_addr,
88 self.query_timeout,
89 )
90 .await
91 })
92 .await
93 }
94
95 pub async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
96 where
97 T: FromRow + Send,
98 {
99 run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
100 fetch_one_compiled(
101 self.client,
102 query,
103 self.tracing_options,
104 self.slow_query_options,
105 crate::config::MssqlRetryOptions::disabled(),
106 &self.server_addr,
107 self.query_timeout,
108 )
109 .await
110 })
111 .await
112 }
113
114 pub async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
115 where
116 T: FromRow + Send,
117 {
118 run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
119 fetch_all_compiled(
120 self.client,
121 query,
122 self.tracing_options,
123 self.slow_query_options,
124 crate::config::MssqlRetryOptions::disabled(),
125 &self.server_addr,
126 self.query_timeout,
127 )
128 .await
129 })
130 .await
131 }
132
133 async fn finish(&mut self, sql: &'static str) -> Result<(), OrmError> {
134 if self.completed {
135 return Err(OrmError::new("transaction has already been completed"));
136 }
137
138 run_transaction_command(
139 self.client,
140 sql,
141 self.query_timeout,
142 self.tracing_options,
143 &self.server_addr,
144 )
145 .await?;
146 self.completed = true;
147
148 Ok(())
149 }
150}
151
152pub(crate) async fn begin_transaction_scope<S>(
153 client: &mut Client<S>,
154 query_timeout: Option<Duration>,
155 tracing_options: MssqlTracingOptions,
156 server_addr: &str,
157) -> Result<(), OrmError>
158where
159 S: AsyncRead + AsyncWrite + Unpin + Send,
160{
161 run_transaction_command(
162 client,
163 BEGIN_TRANSACTION_SQL,
164 query_timeout,
165 tracing_options,
166 server_addr,
167 )
168 .await
169}
170
171pub(crate) async fn commit_transaction_scope<S>(
172 client: &mut Client<S>,
173 query_timeout: Option<Duration>,
174 tracing_options: MssqlTracingOptions,
175 server_addr: &str,
176) -> Result<(), OrmError>
177where
178 S: AsyncRead + AsyncWrite + Unpin + Send,
179{
180 run_transaction_command(
181 client,
182 COMMIT_TRANSACTION_SQL,
183 query_timeout,
184 tracing_options,
185 server_addr,
186 )
187 .await
188}
189
190pub(crate) async fn rollback_transaction_scope<S>(
191 client: &mut Client<S>,
192 query_timeout: Option<Duration>,
193 tracing_options: MssqlTracingOptions,
194 server_addr: &str,
195) -> Result<(), OrmError>
196where
197 S: AsyncRead + AsyncWrite + Unpin + Send,
198{
199 run_transaction_command(
200 client,
201 ROLLBACK_TRANSACTION_SQL,
202 query_timeout,
203 tracing_options,
204 server_addr,
205 )
206 .await
207}
208
209pub(crate) async fn run_transaction_command<S>(
210 client: &mut Client<S>,
211 sql: &'static str,
212 query_timeout: Option<Duration>,
213 tracing_options: MssqlTracingOptions,
214 server_addr: &str,
215) -> Result<(), OrmError>
216where
217 S: AsyncRead + AsyncWrite + Unpin + Send,
218{
219 trace_transaction_command(tracing_options, server_addr, query_timeout, sql, async {
220 run_with_timeout(query_timeout, "SQL Server query timed out", async {
221 client
222 .simple_query(sql)
223 .await
224 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ExecuteQuery))?
225 .into_results()
226 .await
227 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ExecuteQuery))?;
228
229 Ok(())
230 })
231 .await
232 })
233 .await
234}
235
#[cfg(test)]
mod tests {
    use super::{
        BEGIN_TRANSACTION_SQL, COMMIT_TRANSACTION_SQL, MssqlSlowQueryOptions, MssqlTracingOptions,
        MssqlTransaction, ROLLBACK_TRANSACTION_SQL, begin_transaction_scope,
        commit_transaction_scope, rollback_transaction_scope,
    };
    use std::time::Duration;

    /// Concrete stream type used to instantiate the generic transaction API in these tests.
    type TokioStream = tokio_util::compat::Compat<tokio::net::TcpStream>;

    /// Compiles only when `T` is `Copy`.
    fn assert_copy<T: Copy>() {}

    #[test]
    fn transaction_command_constants_match_expected_sql() {
        assert_eq!(BEGIN_TRANSACTION_SQL, "BEGIN TRANSACTION");
        assert_eq!(COMMIT_TRANSACTION_SQL, "COMMIT TRANSACTION");
        assert_eq!(ROLLBACK_TRANSACTION_SQL, "ROLLBACK TRANSACTION");
    }

    /// Checks that the wrapper type can be instantiated for a tokio-backed stream.
    /// Completion tracking itself needs a live server and is not covered here.
    #[test]
    fn transaction_wrapper_instantiates_for_tokio_streams() {
        let wrapper = core::mem::size_of::<Option<MssqlTransaction<'static, TokioStream>>>();

        assert!(wrapper > 0);
    }

    #[test]
    fn exposes_scope_level_transaction_helpers() {
        let begin = begin_transaction_scope::<TokioStream>;
        let commit = commit_transaction_scope::<TokioStream>;
        let rollback = rollback_transaction_scope::<TokioStream>;

        let _ = (begin, commit, rollback);
    }

    /// `MssqlTransaction` copies these values out of `&mut self` for every statement, so they
    /// must stay `Copy`. This replaces a former assertion that compared a value with itself.
    #[test]
    fn transaction_timeout_shape_is_copyable_for_runtime_use() {
        assert_copy::<Option<Duration>>();
        assert_copy::<MssqlTracingOptions>();
        assert_copy::<MssqlSlowQueryOptions>();
    }
}