Skip to main content

tds_protocol/
sql_batch.rs

//! SQL batch request encoding.
//!
//! This module provides encoding for SQL batch requests (packet type 0x01).
//! Per the MS-TDS spec, a SQL batch payload contains:
//! - the ALL_HEADERS section (required for TDS 7.2+)
//! - the SQL text encoded as UTF-16LE
7
8use bytes::{BufMut, Bytes, BytesMut};
9
10use crate::codec::write_utf16_string;
11use crate::prelude::*;
12
13/// Low-level `SQLBatch` wire encoders shared across the workspace crates.
14///
15/// Internal plumbing reached cross-crate only via [`crate::__private`]; not
16/// public API and exempt from semver guarantees (see #242). The public
17/// [`SqlBatch`] builder below is the user-facing API.
18pub(crate) mod sealed {
19    use super::*;
20
21    /// Encode a SQL batch request with auto-commit (no explicit transaction).
22    ///
23    /// The SQL batch packet payload includes:
24    /// 1. ALL_HEADERS section (required for TDS 7.2+)
25    /// 2. SQL text encoded as UTF-16LE
26    ///
27    /// This function returns the encoded payload (without the packet header).
28    /// For requests within an explicit transaction, use [`encode_sql_batch_with_transaction`].
29    ///
30    /// # Example
31    ///
32    /// ```
33    /// use tds_protocol::__private::encode_sql_batch;
34    ///
35    /// let sql = "SELECT * FROM users WHERE id = 1";
36    /// let payload = encode_sql_batch(sql);
37    ///
38    /// // Payload includes ALL_HEADERS + UTF-16LE encoded SQL
39    /// assert!(!payload.is_empty());
40    /// ```
41    #[must_use]
42    pub fn encode_sql_batch(sql: &str) -> Bytes {
43        encode_sql_batch_with_transaction(sql, 0)
44    }
45
46    /// Encode a SQL batch request with a transaction descriptor.
47    ///
48    /// Per MS-TDS spec, when executing within an explicit transaction:
49    /// - The `transaction_descriptor` MUST be the value returned by the server
50    ///   in the BeginTransaction EnvChange token.
51    /// - For auto-commit mode (no explicit transaction), use 0.
52    ///
53    /// # Arguments
54    ///
55    /// * `sql` - The SQL text to execute
56    /// * `transaction_descriptor` - The transaction descriptor from BeginTransaction EnvChange,
57    ///   or 0 for auto-commit mode.
58    ///
59    /// # Example
60    ///
61    /// ```
62    /// use tds_protocol::__private::encode_sql_batch_with_transaction;
63    ///
64    /// // Within a transaction with descriptor 0x1234567890ABCDEF
65    /// let sql = "INSERT INTO users VALUES (1, 'Alice')";
66    /// let tx_descriptor = 0x1234567890ABCDEF_u64;
67    /// let payload = encode_sql_batch_with_transaction(sql, tx_descriptor);
68    /// ```
69    #[must_use]
70    pub fn encode_sql_batch_with_transaction(sql: &str, transaction_descriptor: u64) -> Bytes {
71        // Capacity: ALL_HEADERS (22 bytes) + SQL UTF-16LE (sql.len() * 2)
72        let mut buf = BytesMut::with_capacity(22 + sql.len() * 2);
73
74        // ALL_HEADERS section (required for TDS 7.2+)
75        // Per MS-TDS spec: ALL_HEADERS = TotalLength + Headers
76        let all_headers_start = buf.len();
77        buf.put_u32_le(0); // Total length placeholder
78
79        // Transaction descriptor header (type 0x0002)
80        // Per MS-TDS 2.2.5.3: HeaderLength (4) + HeaderType (2) + TransactionDescriptor (8) + OutstandingRequestCount (4)
81        buf.put_u32_le(18); // Header length = 18 bytes
82        buf.put_u16_le(0x0002); // Header type: transaction descriptor
83        buf.put_u64_le(transaction_descriptor); // Transaction descriptor from BeginTransaction EnvChange
84        buf.put_u32_le(1); // Outstanding request count (1 for non-MARS connections)
85
86        // Fill in ALL_HEADERS total length
87        let all_headers_len = buf.len() - all_headers_start;
88        let len_bytes = (all_headers_len as u32).to_le_bytes();
89        buf[all_headers_start..all_headers_start + 4].copy_from_slice(&len_bytes);
90
91        // SQL text as UTF-16LE
92        write_utf16_string(&mut buf, sql);
93
94        buf.freeze()
95    }
96}
97
98// Keep `encode_sql_batch` reachable for the intra-crate `SqlBatch::encode`
99// caller below; off the public surface (see crate::__private).
100pub(crate) use sealed::encode_sql_batch;
101
102/// SQL batch builder for more complex batches.
103///
104/// This can be used to build batches with multiple statements
105/// or to add headers for specific features.
106#[derive(Debug, Clone)]
107pub struct SqlBatch {
108    sql: String,
109}
110
111impl SqlBatch {
112    /// Create a new SQL batch.
113    #[must_use]
114    pub fn new(sql: impl Into<String>) -> Self {
115        Self { sql: sql.into() }
116    }
117
118    /// Get the SQL text.
119    #[must_use]
120    pub fn sql(&self) -> &str {
121        &self.sql
122    }
123
124    /// Encode the SQL batch to bytes.
125    #[must_use]
126    pub fn encode(&self) -> Bytes {
127        encode_sql_batch(&self.sql)
128    }
129}
130
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Checks the full auto-commit layout: ALL_HEADERS fields, then UTF-16LE SQL.
    #[test]
    fn test_encode_sql_batch() {
        let sql = "SELECT 1";
        let payload = encode_sql_batch(sql);

        // ALL_HEADERS (22 bytes) + UTF-16LE encoded (8 chars * 2 bytes = 16 bytes) = 38 bytes
        assert_eq!(payload.len(), 38);

        // TotalLength at bytes 0-3 (little-endian)
        assert_eq!(&payload[0..4], &[22, 0, 0, 0]);

        // HeaderLength at bytes 4-7
        assert_eq!(&payload[4..8], &[18, 0, 0, 0]);

        // HeaderType at bytes 8-9: transaction descriptor
        assert_eq!(&payload[8..10], &[0x02, 0x00]);

        // Auto-commit: TransactionDescriptor at bytes 10-17 is zero
        assert_eq!(&payload[10..18], &[0; 8]);

        // OutstandingRequestCount at bytes 18-21
        assert_eq!(&payload[18..22], &[1, 0, 0, 0]);

        // UTF-16LE SQL starts at byte 22
        // 'S' = 0x53, 'E' = 0x45, etc.
        assert_eq!(payload[22], b'S');
        assert_eq!(payload[23], 0);
        assert_eq!(payload[24], b'E');
        assert_eq!(payload[25], 0);
    }

    /// The explicit-transaction encoder differs from auto-commit only in the descriptor bytes.
    #[test]
    fn test_encode_sql_batch_with_transaction() {
        let descriptor = 0x1234_5678_90AB_CDEF_u64;
        let payload = sealed::encode_sql_batch_with_transaction("SELECT 1", descriptor);

        assert_eq!(payload.len(), 38);
        assert_eq!(&payload[0..4], &[22, 0, 0, 0]);
        assert_eq!(&payload[10..18], &descriptor.to_le_bytes());
        assert_eq!(&payload[18..22], &[1, 0, 0, 0]);

        let auto_commit = encode_sql_batch("SELECT 1");
        assert_eq!(&payload[..10], &auto_commit[..10]);
        assert_eq!(&payload[18..], &auto_commit[18..]);
    }

    /// Non-ASCII text must be UTF-16LE, including surrogate pairs.
    #[test]
    fn test_non_ascii_sql_is_utf16le() {
        // U+00E9 is one UTF-16 code unit; U+1F600 is the pair D83D DE00.
        let payload = encode_sql_batch("\u{e9}\u{1F600}");
        assert_eq!(&payload[22..], &[0xE9, 0x00, 0x3D, 0xD8, 0x00, 0xDE]);
    }

    /// The builder keeps its SQL and encodes exactly like the auto-commit free function.
    #[test]
    fn test_sql_batch_builder() {
        let batch = SqlBatch::new("SELECT @@VERSION");
        assert_eq!(batch.sql(), "SELECT @@VERSION");

        let payload = batch.encode();
        assert!(!payload.is_empty());
        assert_eq!(payload, encode_sql_batch("SELECT @@VERSION"));
    }

    /// Empty SQL still carries the mandatory ALL_HEADERS section.
    #[test]
    fn test_empty_batch() {
        let payload = encode_sql_batch("");
        // Even empty SQL has ALL_HEADERS (22 bytes)
        assert_eq!(payload.len(), 22);
    }
}