tds_protocol/
sql_batch.rs

//! SQL batch request encoding.
//!
//! This module encodes SQL batch requests (TDS packet type 0x01).
//! Per the MS-TDS specification, a SQL batch payload contains:
//! - an ALL_HEADERS section (required for TDS 7.2+), followed by
//! - the SQL text encoded as UTF-16LE.

8use bytes::{BufMut, Bytes, BytesMut};
9
10use crate::codec::write_utf16_string;
11use crate::prelude::*;
12
13/// Encode a SQL batch request with auto-commit (no explicit transaction).
14///
15/// The SQL batch packet payload includes:
16/// 1. ALL_HEADERS section (required for TDS 7.2+)
17/// 2. SQL text encoded as UTF-16LE
18///
19/// This function returns the encoded payload (without the packet header).
20/// For requests within an explicit transaction, use [`encode_sql_batch_with_transaction`].
21///
22/// # Example
23///
24/// ```
25/// use tds_protocol::sql_batch::encode_sql_batch;
26///
27/// let sql = "SELECT * FROM users WHERE id = 1";
28/// let payload = encode_sql_batch(sql);
29///
30/// // Payload includes ALL_HEADERS + UTF-16LE encoded SQL
31/// assert!(!payload.is_empty());
32/// ```
33#[must_use]
34pub fn encode_sql_batch(sql: &str) -> Bytes {
35    encode_sql_batch_with_transaction(sql, 0)
36}
37
38/// Encode a SQL batch request with a transaction descriptor.
39///
40/// Per MS-TDS spec, when executing within an explicit transaction:
41/// - The `transaction_descriptor` MUST be the value returned by the server
42///   in the BeginTransaction EnvChange token.
43/// - For auto-commit mode (no explicit transaction), use 0.
44///
45/// # Arguments
46///
47/// * `sql` - The SQL text to execute
48/// * `transaction_descriptor` - The transaction descriptor from BeginTransaction EnvChange,
49///   or 0 for auto-commit mode.
50///
51/// # Example
52///
53/// ```
54/// use tds_protocol::sql_batch::encode_sql_batch_with_transaction;
55///
56/// // Within a transaction with descriptor 0x1234567890ABCDEF
57/// let sql = "INSERT INTO users VALUES (1, 'Alice')";
58/// let tx_descriptor = 0x1234567890ABCDEF_u64;
59/// let payload = encode_sql_batch_with_transaction(sql, tx_descriptor);
60/// ```
61#[must_use]
62pub fn encode_sql_batch_with_transaction(sql: &str, transaction_descriptor: u64) -> Bytes {
63    // Capacity: ALL_HEADERS (22 bytes) + SQL UTF-16LE (sql.len() * 2)
64    let mut buf = BytesMut::with_capacity(22 + sql.len() * 2);
65
66    // ALL_HEADERS section (required for TDS 7.2+)
67    // Per MS-TDS spec: ALL_HEADERS = TotalLength + Headers
68    let all_headers_start = buf.len();
69    buf.put_u32_le(0); // Total length placeholder
70
71    // Transaction descriptor header (type 0x0002)
72    // Per MS-TDS 2.2.5.3: HeaderLength (4) + HeaderType (2) + TransactionDescriptor (8) + OutstandingRequestCount (4)
73    buf.put_u32_le(18); // Header length = 18 bytes
74    buf.put_u16_le(0x0002); // Header type: transaction descriptor
75    buf.put_u64_le(transaction_descriptor); // Transaction descriptor from BeginTransaction EnvChange
76    buf.put_u32_le(1); // Outstanding request count (1 for non-MARS connections)
77
78    // Fill in ALL_HEADERS total length
79    let all_headers_len = buf.len() - all_headers_start;
80    let len_bytes = (all_headers_len as u32).to_le_bytes();
81    buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
82
83    // SQL text as UTF-16LE
84    write_utf16_string(&mut buf, sql);
85
86    buf.freeze()
87}
88
89/// SQL batch builder for more complex batches.
90///
91/// This can be used to build batches with multiple statements
92/// or to add headers for specific features.
93#[derive(Debug, Clone)]
94pub struct SqlBatch {
95    sql: String,
96}
97
98impl SqlBatch {
99    /// Create a new SQL batch.
100    #[must_use]
101    pub fn new(sql: impl Into<String>) -> Self {
102        Self { sql: sql.into() }
103    }
104
105    /// Get the SQL text.
106    #[must_use]
107    pub fn sql(&self) -> &str {
108        &self.sql
109    }
110
111    /// Encode the SQL batch to bytes.
112    #[must_use]
113    pub fn encode(&self) -> Bytes {
114        encode_sql_batch(&self.sql)
115    }
116}
117
#[cfg(test)]
// Test code may use unwrap(): a panic there is simply a test failure.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Byte offset where the SQL text starts (size of the ALL_HEADERS section).
    const SQL_OFFSET: usize = 22;

    #[test]
    fn test_encode_sql_batch() {
        let sql = "SELECT 1";
        let payload = encode_sql_batch(sql);

        // ALL_HEADERS (22 bytes) + UTF-16LE encoded (8 chars * 2 bytes = 16 bytes) = 38 bytes
        assert_eq!(payload.len(), 38);

        // Verify ALL_HEADERS section
        // Total length at bytes 0-3 (little-endian)
        assert_eq!(&payload[0..4], &[22, 0, 0, 0]); // TotalLength = 22

        // Header length at bytes 4-7
        assert_eq!(&payload[4..8], &[18, 0, 0, 0]); // HeaderLength = 18

        // Header type at bytes 8-9
        assert_eq!(&payload[8..10], &[0x02, 0x00]); // Transaction descriptor

        // Verify UTF-16LE SQL starts at byte 22
        // 'S' = 0x53, 'E' = 0x45, etc.
        assert_eq!(payload[22], b'S');
        assert_eq!(payload[23], 0);
        assert_eq!(payload[24], b'E');
        assert_eq!(payload[25], 0);
    }

    #[test]
    fn test_transaction_descriptor_header_fields() {
        let descriptor = 0x1234_5678_90AB_CDEF_u64;
        let payload = encode_sql_batch_with_transaction("SELECT 1", descriptor);

        assert_eq!(payload.len(), 38);
        assert_eq!(&payload[0..4], &22u32.to_le_bytes());
        // TransactionDescriptor occupies bytes 10-17 (little-endian).
        assert_eq!(&payload[10..18], &descriptor.to_le_bytes());
        // OutstandingRequestCount occupies bytes 18-21.
        assert_eq!(&payload[18..22], &1u32.to_le_bytes());
    }

    #[test]
    fn test_auto_commit_uses_zero_descriptor() {
        let sql = "SELECT 1";
        let payload = encode_sql_batch(sql);

        assert_eq!(payload, encode_sql_batch_with_transaction(sql, 0));
        assert_eq!(&payload[10..18], &[0u8; 8]);
    }

    #[test]
    fn test_non_ascii_sql_is_utf16le() {
        // BMP characters (é, €) plus one supplementary character that needs a surrogate pair.
        let sql = "SELECT N'h\u{e9}llo \u{20ac} \u{1F600}'";
        let payload = encode_sql_batch(sql);

        let expected: Vec<u8> = sql.encode_utf16().flat_map(u16::to_le_bytes).collect();
        assert_eq!(&payload[SQL_OFFSET..], expected.as_slice());
    }

    #[test]
    fn test_sql_batch_builder() {
        let batch = SqlBatch::new("SELECT @@VERSION");
        assert_eq!(batch.sql(), "SELECT @@VERSION");

        let payload = batch.encode();
        assert!(!payload.is_empty());
    }

    #[test]
    fn test_sql_batch_builder_matches_free_function() {
        let sql = "SELECT @@VERSION";
        assert_eq!(SqlBatch::new(sql).encode(), encode_sql_batch(sql));
    }

    #[test]
    fn test_empty_batch() {
        let payload = encode_sql_batch("");
        // Even empty SQL has ALL_HEADERS (22 bytes)
        assert_eq!(payload.len(), 22);
    }
}