1use crate::config::{MssqlSlowQueryOptions, MssqlTracingOptions};
2use crate::connection::run_with_timeout;
3use crate::error::{TiberiusErrorContext, map_tiberius_error};
4use crate::executor::{
5 ExecuteResult, QueryExecutionOptions, execute_compiled, fetch_all_compiled, fetch_one_compiled,
6 query_raw_compiled,
7};
8use crate::telemetry::trace_transaction_command;
9use futures_io::{AsyncRead, AsyncWrite};
10use sql_orm_core::{FromRow, OrmError};
11use sql_orm_query::CompiledQuery;
12use std::time::Duration;
13use tiberius::{Client, QueryStream};
14
/// T-SQL statement that opens a transaction on the current connection.
const BEGIN_TRANSACTION_SQL: &str = "BEGIN TRANSACTION";

/// T-SQL statement that ends the current transaction by committing it.
const COMMIT_TRANSACTION_SQL: &str = "COMMIT TRANSACTION";

/// T-SQL statement that ends the current transaction by rolling it back.
const ROLLBACK_TRANSACTION_SQL: &str = "ROLLBACK TRANSACTION";
18
19pub struct MssqlTransaction<'a, S: AsyncRead + AsyncWrite + Unpin + Send> {
20 client: &'a mut Client<S>,
21 query_timeout: Option<Duration>,
22 tracing_options: MssqlTracingOptions,
23 slow_query_options: MssqlSlowQueryOptions,
24 server_addr: String,
25 completed: bool,
26}
27
28impl<'a, S> MssqlTransaction<'a, S>
29where
30 S: AsyncRead + AsyncWrite + Unpin + Send,
31{
32 pub(crate) async fn begin(
33 client: &'a mut Client<S>,
34 query_timeout: Option<Duration>,
35 tracing_options: MssqlTracingOptions,
36 slow_query_options: MssqlSlowQueryOptions,
37 server_addr: String,
38 ) -> Result<Self, OrmError> {
39 begin_transaction_scope(client, query_timeout, tracing_options, &server_addr).await?;
40
41 Ok(Self {
42 client,
43 query_timeout,
44 tracing_options,
45 slow_query_options,
46 server_addr,
47 completed: false,
48 })
49 }
50
51 pub fn is_completed(&self) -> bool {
52 self.completed
53 }
54
55 pub async fn commit(mut self) -> Result<(), OrmError> {
56 self.finish(COMMIT_TRANSACTION_SQL).await
57 }
58
59 pub async fn rollback(mut self) -> Result<(), OrmError> {
60 self.finish(ROLLBACK_TRANSACTION_SQL).await
61 }
62
63 pub async fn execute(&mut self, query: CompiledQuery) -> Result<ExecuteResult, OrmError> {
64 run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
65 execute_compiled(
66 self.client,
67 query,
68 self.tracing_options,
69 self.slow_query_options,
70 &self.server_addr,
71 self.query_timeout,
72 )
73 .await
74 })
75 .await
76 }
77
78 pub async fn query_raw<'b>(
79 &'b mut self,
80 query: CompiledQuery,
81 ) -> Result<QueryStream<'b>, OrmError> {
82 run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
83 query_raw_compiled(
84 self.client,
85 query,
86 self.tracing_options,
87 self.slow_query_options,
88 &self.server_addr,
89 self.query_timeout,
90 )
91 .await
92 })
93 .await
94 }
95
96 pub async fn fetch_one<T>(&mut self, query: CompiledQuery) -> Result<Option<T>, OrmError>
97 where
98 T: FromRow + Send,
99 {
100 run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
101 fetch_one_compiled(
102 self.client,
103 query,
104 QueryExecutionOptions {
105 tracing: self.tracing_options,
106 slow_query: self.slow_query_options,
107 retry: crate::config::MssqlRetryOptions::disabled(),
108 server_addr: &self.server_addr,
109 timeout: self.query_timeout,
110 },
111 )
112 .await
113 })
114 .await
115 }
116
117 pub async fn fetch_all<T>(&mut self, query: CompiledQuery) -> Result<Vec<T>, OrmError>
118 where
119 T: FromRow + Send,
120 {
121 run_with_timeout(self.query_timeout, "SQL Server query timed out", async {
122 fetch_all_compiled(
123 self.client,
124 query,
125 QueryExecutionOptions {
126 tracing: self.tracing_options,
127 slow_query: self.slow_query_options,
128 retry: crate::config::MssqlRetryOptions::disabled(),
129 server_addr: &self.server_addr,
130 timeout: self.query_timeout,
131 },
132 )
133 .await
134 })
135 .await
136 }
137
138 async fn finish(&mut self, sql: &'static str) -> Result<(), OrmError> {
139 if self.completed {
140 return Err(OrmError::transaction(
141 "transaction has already been completed",
142 ));
143 }
144
145 run_transaction_command(
146 self.client,
147 sql,
148 self.query_timeout,
149 self.tracing_options,
150 &self.server_addr,
151 )
152 .await?;
153 self.completed = true;
154
155 Ok(())
156 }
157}
158
159pub(crate) async fn begin_transaction_scope<S>(
160 client: &mut Client<S>,
161 query_timeout: Option<Duration>,
162 tracing_options: MssqlTracingOptions,
163 server_addr: &str,
164) -> Result<(), OrmError>
165where
166 S: AsyncRead + AsyncWrite + Unpin + Send,
167{
168 run_transaction_command(
169 client,
170 BEGIN_TRANSACTION_SQL,
171 query_timeout,
172 tracing_options,
173 server_addr,
174 )
175 .await
176}
177
178pub(crate) async fn commit_transaction_scope<S>(
179 client: &mut Client<S>,
180 query_timeout: Option<Duration>,
181 tracing_options: MssqlTracingOptions,
182 server_addr: &str,
183) -> Result<(), OrmError>
184where
185 S: AsyncRead + AsyncWrite + Unpin + Send,
186{
187 run_transaction_command(
188 client,
189 COMMIT_TRANSACTION_SQL,
190 query_timeout,
191 tracing_options,
192 server_addr,
193 )
194 .await
195}
196
197pub(crate) async fn rollback_transaction_scope<S>(
198 client: &mut Client<S>,
199 query_timeout: Option<Duration>,
200 tracing_options: MssqlTracingOptions,
201 server_addr: &str,
202) -> Result<(), OrmError>
203where
204 S: AsyncRead + AsyncWrite + Unpin + Send,
205{
206 run_transaction_command(
207 client,
208 ROLLBACK_TRANSACTION_SQL,
209 query_timeout,
210 tracing_options,
211 server_addr,
212 )
213 .await
214}
215
216pub(crate) async fn run_transaction_command<S>(
217 client: &mut Client<S>,
218 sql: &'static str,
219 query_timeout: Option<Duration>,
220 tracing_options: MssqlTracingOptions,
221 server_addr: &str,
222) -> Result<(), OrmError>
223where
224 S: AsyncRead + AsyncWrite + Unpin + Send,
225{
226 trace_transaction_command(tracing_options, server_addr, query_timeout, sql, async {
227 run_with_timeout(query_timeout, "SQL Server query timed out", async {
228 client
229 .simple_query(sql)
230 .await
231 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ExecuteQuery))?
232 .into_results()
233 .await
234 .map_err(|error| map_tiberius_error(&error, TiberiusErrorContext::ExecuteQuery))?;
235
236 Ok(())
237 })
238 .await
239 })
240 .await
241}
242
#[cfg(test)]
mod tests {
    use super::{
        BEGIN_TRANSACTION_SQL, COMMIT_TRANSACTION_SQL, MssqlSlowQueryOptions, MssqlTracingOptions,
        MssqlTransaction, ROLLBACK_TRANSACTION_SQL, begin_transaction_scope,
        commit_transaction_scope, rollback_transaction_scope,
    };
    use std::time::Duration;

    /// Concrete stream type used to instantiate the generic helpers without a live server.
    type TestStream = tokio_util::compat::Compat<tokio::net::TcpStream>;

    /// Compiles only when `T: Copy`; used as a compile-time assertion.
    fn assert_copy<T: Copy>() {}

    #[test]
    fn transaction_command_constants_match_expected_sql() {
        assert_eq!(BEGIN_TRANSACTION_SQL, "BEGIN TRANSACTION");
        assert_eq!(COMMIT_TRANSACTION_SQL, "COMMIT TRANSACTION");
        assert_eq!(ROLLBACK_TRANSACTION_SQL, "ROLLBACK TRANSACTION");
    }

    #[test]
    fn transaction_wrapper_exposes_completion_state() {
        // Without a live SQL Server connection only the shape can be checked: the wrapper
        // is a concrete sized type, and `is_completed` reads it through a shared borrow.
        let wrapper = core::mem::size_of::<Option<MssqlTransaction<'static, TestStream>>>();
        let is_completed: fn(&MssqlTransaction<'static, TestStream>) -> bool =
            MssqlTransaction::is_completed;

        assert!(wrapper > 0);
        let _ = is_completed;
    }

    #[test]
    fn exposes_scope_level_transaction_helpers() {
        let begin = begin_transaction_scope::<TestStream>;
        let commit = commit_transaction_scope::<TestStream>;
        let rollback = rollback_transaction_scope::<TestStream>;

        let _ = (begin, commit, rollback);
    }

    #[test]
    fn runtime_options_are_copy_for_repeated_use() {
        // The transaction hands these values by value to every command it runs, which
        // relies on them being `Copy`. The old test here only compared a value to itself.
        assert_copy::<Option<Duration>>();
        assert_copy::<MssqlTracingOptions>();
        assert_copy::<MssqlSlowQueryOptions>();
    }
}