tds_protocol/sql_batch.rs
1//! SQL batch request encoding.
2//!
3//! This module provides encoding for SQL batch requests (packet type 0x01).
4//! Per MS-TDS spec, a SQL batch contains:
5//! - ALL_HEADERS section (required for TDS 7.2+)
6//! - SQL text encoded as UTF-16LE
7
8use bytes::{BufMut, Bytes, BytesMut};
9
10use crate::codec::write_utf16_string;
11
12/// Encode a SQL batch request with auto-commit (no explicit transaction).
13///
14/// The SQL batch packet payload includes:
15/// 1. ALL_HEADERS section (required for TDS 7.2+)
16/// 2. SQL text encoded as UTF-16LE
17///
18/// This function returns the encoded payload (without the packet header).
19/// For requests within an explicit transaction, use [`encode_sql_batch_with_transaction`].
20///
21/// # Example
22///
23/// ```
24/// use tds_protocol::sql_batch::encode_sql_batch;
25///
26/// let sql = "SELECT * FROM users WHERE id = 1";
27/// let payload = encode_sql_batch(sql);
28///
29/// // Payload includes ALL_HEADERS + UTF-16LE encoded SQL
30/// assert!(!payload.is_empty());
31/// ```
32#[must_use]
33pub fn encode_sql_batch(sql: &str) -> Bytes {
34 encode_sql_batch_with_transaction(sql, 0)
35}
36
37/// Encode a SQL batch request with a transaction descriptor.
38///
39/// Per MS-TDS spec, when executing within an explicit transaction:
40/// - The `transaction_descriptor` MUST be the value returned by the server
41/// in the BeginTransaction EnvChange token.
42/// - For auto-commit mode (no explicit transaction), use 0.
43///
44/// # Arguments
45///
46/// * `sql` - The SQL text to execute
47/// * `transaction_descriptor` - The transaction descriptor from BeginTransaction EnvChange,
48/// or 0 for auto-commit mode.
49///
50/// # Example
51///
52/// ```
53/// use tds_protocol::sql_batch::encode_sql_batch_with_transaction;
54///
55/// // Within a transaction with descriptor 0x1234567890ABCDEF
56/// let sql = "INSERT INTO users VALUES (1, 'Alice')";
57/// let tx_descriptor = 0x1234567890ABCDEF_u64;
58/// let payload = encode_sql_batch_with_transaction(sql, tx_descriptor);
59/// ```
60#[must_use]
61pub fn encode_sql_batch_with_transaction(sql: &str, transaction_descriptor: u64) -> Bytes {
62 // Capacity: ALL_HEADERS (22 bytes) + SQL UTF-16LE (sql.len() * 2)
63 let mut buf = BytesMut::with_capacity(22 + sql.len() * 2);
64
65 // ALL_HEADERS section (required for TDS 7.2+)
66 // Per MS-TDS spec: ALL_HEADERS = TotalLength + Headers
67 let all_headers_start = buf.len();
68 buf.put_u32_le(0); // Total length placeholder
69
70 // Transaction descriptor header (type 0x0002)
71 // Per MS-TDS 2.2.5.3: HeaderLength (4) + HeaderType (2) + TransactionDescriptor (8) + OutstandingRequestCount (4)
72 buf.put_u32_le(18); // Header length = 18 bytes
73 buf.put_u16_le(0x0002); // Header type: transaction descriptor
74 buf.put_u64_le(transaction_descriptor); // Transaction descriptor from BeginTransaction EnvChange
75 buf.put_u32_le(1); // Outstanding request count (1 for non-MARS connections)
76
77 // Fill in ALL_HEADERS total length
78 let all_headers_len = buf.len() - all_headers_start;
79 let len_bytes = (all_headers_len as u32).to_le_bytes();
80 buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
81
82 // SQL text as UTF-16LE
83 write_utf16_string(&mut buf, sql);
84
85 buf.freeze()
86}
87
88/// SQL batch builder for more complex batches.
89///
90/// This can be used to build batches with multiple statements
91/// or to add headers for specific features.
92#[derive(Debug, Clone)]
93pub struct SqlBatch {
94 sql: String,
95}
96
97impl SqlBatch {
98 /// Create a new SQL batch.
99 #[must_use]
100 pub fn new(sql: impl Into<String>) -> Self {
101 Self { sql: sql.into() }
102 }
103
104 /// Get the SQL text.
105 #[must_use]
106 pub fn sql(&self) -> &str {
107 &self.sql
108 }
109
110 /// Encode the SQL batch to bytes.
111 #[must_use]
112 pub fn encode(&self) -> Bytes {
113 encode_sql_batch(&self.sql)
114 }
115}
116
#[cfg(test)]
// The `unwrap` lint is relaxed for test code in this module.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Size in bytes of the ALL_HEADERS section that precedes the SQL text.
    const ALL_HEADERS_LEN: usize = 22;

    #[test]
    fn test_encode_sql_batch() {
        let payload = encode_sql_batch("SELECT 1");

        // 22 header bytes + 8 characters * 2 bytes each = 38 bytes
        assert_eq!(payload.len(), ALL_HEADERS_LEN + 8 * 2);

        // TotalLength occupies bytes 0-3 (little-endian)
        assert_eq!(&payload[..4], &22_u32.to_le_bytes());

        // HeaderLength occupies bytes 4-7
        assert_eq!(&payload[4..8], &18_u32.to_le_bytes());

        // HeaderType occupies bytes 8-9: transaction descriptor
        assert_eq!(&payload[8..10], &0x0002_u16.to_le_bytes());

        // The SQL text starts right after the headers: 'S', 'E' as UTF-16LE
        let sql_start = ALL_HEADERS_LEN;
        assert_eq!(&payload[sql_start..sql_start + 4], &[b'S', 0, b'E', 0]);
    }

    #[test]
    fn test_sql_batch_builder() {
        let text = "SELECT @@VERSION";
        let batch = SqlBatch::new(text);
        assert_eq!(batch.sql(), text);

        let encoded = batch.encode();
        assert!(!encoded.is_empty());
    }

    #[test]
    fn test_empty_batch() {
        // Empty SQL still carries the 22-byte ALL_HEADERS section.
        let payload = encode_sql_batch("");
        assert_eq!(payload.len(), ALL_HEADERS_LEN);
    }
}