// sqlx_core/mssql/transaction.rs

1use std::borrow::Cow;
2
3use futures_core::future::BoxFuture;
4
5use crate::error::Error;
6use crate::executor::Executor;
7use crate::mssql::protocol::packet::PacketType;
8use crate::mssql::protocol::sql_batch::SqlBatch;
9use crate::mssql::{Mssql, MssqlConnection};
10use crate::transaction::TransactionManager;
11
/// Implementation of [`TransactionManager`] for MSSQL.
///
/// This is a field-less unit type: it carries no state of its own. All
/// transaction bookkeeping (`transaction_depth`, `pending_done_count`,
/// `transaction_descriptor`) lives on the connection's `stream` and is read
/// and updated by the trait methods. Nested transactions are emulated with
/// `SAVE TRAN` savepoints.
pub struct MssqlTransactionManager;
14
15impl TransactionManager for MssqlTransactionManager {
16    type Database = Mssql;
17
18    fn begin(conn: &mut MssqlConnection) -> BoxFuture<'_, Result<(), Error>> {
19        Box::pin(async move {
20            let depth = conn.stream.transaction_depth;
21
22            let query = if depth == 0 {
23                Cow::Borrowed("BEGIN TRAN ")
24            } else {
25                Cow::Owned(format!("SAVE TRAN _sqlx_savepoint_{}", depth))
26            };
27
28            conn.execute(&*query).await?;
29            conn.stream.transaction_depth = depth + 1;
30
31            Ok(())
32        })
33    }
34
35    fn commit(conn: &mut MssqlConnection) -> BoxFuture<'_, Result<(), Error>> {
36        Box::pin(async move {
37            let depth = conn.stream.transaction_depth;
38
39            if depth > 0 {
40                if depth == 1 {
41                    // savepoints are not released in MSSQL
42                    conn.execute("COMMIT TRAN").await?;
43                }
44
45                conn.stream.transaction_depth = depth - 1;
46            }
47
48            Ok(())
49        })
50    }
51
52    fn rollback(conn: &mut MssqlConnection) -> BoxFuture<'_, Result<(), Error>> {
53        Box::pin(async move {
54            let depth = conn.stream.transaction_depth;
55
56            if depth > 0 {
57                let query = if depth == 1 {
58                    Cow::Borrowed("ROLLBACK TRAN")
59                } else {
60                    Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1))
61                };
62
63                conn.execute(&*query).await?;
64                conn.stream.transaction_depth = depth - 1;
65            }
66
67            Ok(())
68        })
69    }
70
71    fn start_rollback(conn: &mut MssqlConnection) {
72        let depth = conn.stream.transaction_depth;
73
74        if depth > 0 {
75            let query = if depth == 1 {
76                Cow::Borrowed("ROLLBACK TRAN")
77            } else {
78                Cow::Owned(format!("ROLLBACK TRAN _sqlx_savepoint_{}", depth - 1))
79            };
80
81            conn.stream.pending_done_count += 1;
82
83            conn.stream.write_packet(
84                PacketType::SqlBatch,
85                SqlBatch {
86                    transaction_descriptor: conn.stream.transaction_descriptor,
87                    sql: &*query,
88                },
89            );
90
91            conn.stream.transaction_depth = depth - 1;
92        }
93    }
94}