tds_protocol/
crypto.rs

//! Always Encrypted cryptography metadata for the TDS protocol.
//!
//! This module defines the wire-level structures for SQL Server's Always Encrypted
//! feature. When a query returns encrypted columns, SQL Server sends additional
//! metadata describing how to decrypt the data.
//!
//! ## TDS Wire Format
//!
//! When Always Encrypted is enabled, the COLMETADATA token includes:
//!
//! 1. **CEK Table**: A table of Column Encryption Keys needed for the result set
//! 2. **CryptoMetadata**: Per-column encryption information
//!
//! ```text
//! COLMETADATA Token (with encryption):
//! ┌─────────────────────────────────────────────────────────────────┐
//! │ Column Count (2 bytes)                                          │
//! ├─────────────────────────────────────────────────────────────────┤
//! │ CEK Table (if encrypted columns present)                        │
//! │ ├── CEK Count (2 bytes)                                         │
//! │ ├── CEK Entry 1                                                 │
//! │ │   ├── Database ID (4 bytes)                                   │
//! │ │   ├── CEK ID (4 bytes)                                        │
//! │ │   ├── CEK Version (4 bytes)                                   │
//! │ │   ├── CEK MD Version (8 bytes)                                │
//! │ │   ├── CEK Value Count (1 byte)                                │
//! │ │   └── CEK Value(s)                                            │
//! │ │       ├── Encrypted Value Length (2 bytes)                    │
//! │ │       ├── Encrypted Value (variable)                          │
//! │ │       ├── Key Store Name (B_VARCHAR)                          │
//! │ │       ├── CMK Path (US_VARCHAR)                               │
//! │ │       └── Algorithm (B_VARCHAR)                               │
//! │ └── ...more CEK entries                                         │
//! ├─────────────────────────────────────────────────────────────────┤
//! │ Column Definitions                                              │
//! │ ├── Column 1                                                    │
//! │ │   ├── User Type (4 bytes)                                     │
//! │ │   ├── Flags (2 bytes) - includes encryption flag              │
//! │ │   ├── Type ID (1 byte)                                        │
//! │ │   ├── Type Info (variable)                                    │
//! │ │   ├── CryptoMetadata (if encrypted)                           │
//! │ │   │   ├── CEK Table Ordinal (2 bytes)                         │
//! │ │   │   ├── Algorithm ID (1 byte)                               │
//! │ │   │   ├── Encryption Type (1 byte)                            │
//! │ │   │   └── Normalization Version (1 byte)                      │
//! │ │   └── Column Name (B_VARCHAR)                                 │
//! │ └── ...more columns                                             │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
50
51use bytes::{Buf, Bytes};
52
53#[cfg(not(feature = "std"))]
54use alloc::string::String;
55#[cfg(not(feature = "std"))]
56use alloc::vec::Vec;
57
58use crate::codec::{read_b_varchar, read_us_varchar};
59use crate::error::ProtocolError;
60
/// Bit of the COLMETADATA column `Flags` word that marks a column as
/// encrypted (bit 11, i.e. `0x0800`).
pub const COLUMN_FLAG_ENCRYPTED: u16 = 1 << 11;

/// Wire ID of the `AEAD_AES_256_CBC_HMAC_SHA256` column encryption algorithm.
pub const ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256: u8 = 2;

/// Wire value of the deterministic encryption type
/// (decoded as `EncryptionTypeWire::Deterministic`).
pub const ENCRYPTION_TYPE_DETERMINISTIC: u8 = 1;

/// Wire value of the randomized encryption type
/// (decoded as `EncryptionTypeWire::Randomized`).
pub const ENCRYPTION_TYPE_RANDOMIZED: u8 = 2;

/// Current normalization rule version (`1`).
pub const NORMALIZATION_RULE_VERSION: u8 = 1;
75
76/// Column Encryption Key table entry.
77///
78/// This represents a single CEK entry in the CEK table sent with COLMETADATA.
79/// Multiple columns may share the same CEK.
80#[derive(Debug, Clone)]
81pub struct CekTableEntry {
82    /// Database ID where the CEK is defined.
83    pub database_id: u32,
84    /// CEK ID within the database.
85    pub cek_id: u32,
86    /// CEK version (incremented on key rotation).
87    pub cek_version: u32,
88    /// Metadata version (changes with any metadata update).
89    pub cek_md_version: u64,
90    /// CEK value entries (usually one, but may have multiple for key rotation).
91    pub values: Vec<CekValue>,
92}
93
94/// A single CEK value (encrypted by CMK).
95///
96/// A CEK may have multiple values when key rotation is in progress,
97/// with different CMKs encrypting the same CEK.
98#[derive(Debug, Clone)]
99pub struct CekValue {
100    /// The encrypted CEK bytes.
101    pub encrypted_value: Bytes,
102    /// Name of the key store provider (e.g., "AZURE_KEY_VAULT").
103    pub key_store_provider_name: String,
104    /// Path to the Column Master Key in the key store.
105    pub cmk_path: String,
106    /// Asymmetric algorithm used to encrypt the CEK (e.g., "RSA_OAEP").
107    pub encryption_algorithm: String,
108}
109
110/// Per-column encryption metadata.
111///
112/// This metadata is present for each encrypted column and describes
113/// how to decrypt the column data.
114#[derive(Debug, Clone)]
115pub struct CryptoMetadata {
116    /// Index into the CEK table (0-based).
117    pub cek_table_ordinal: u16,
118    /// Encryption algorithm ID.
119    pub algorithm_id: u8,
120    /// Encryption type (deterministic or randomized).
121    pub encryption_type: EncryptionTypeWire,
122    /// Normalization rule version.
123    pub normalization_version: u8,
124}
125
126/// Wire-level encryption type.
127#[derive(Debug, Clone, Copy, PartialEq, Eq)]
128pub enum EncryptionTypeWire {
129    /// Deterministic encryption (value 1).
130    Deterministic,
131    /// Randomized encryption (value 2).
132    Randomized,
133}
134
135impl EncryptionTypeWire {
136    /// Create from wire value.
137    #[must_use]
138    pub fn from_u8(value: u8) -> Option<Self> {
139        match value {
140            ENCRYPTION_TYPE_DETERMINISTIC => Some(Self::Deterministic),
141            ENCRYPTION_TYPE_RANDOMIZED => Some(Self::Randomized),
142            _ => None,
143        }
144    }
145
146    /// Convert to wire value.
147    #[must_use]
148    pub fn to_u8(self) -> u8 {
149        match self {
150            Self::Deterministic => ENCRYPTION_TYPE_DETERMINISTIC,
151            Self::Randomized => ENCRYPTION_TYPE_RANDOMIZED,
152        }
153    }
154}
155
156/// CEK table containing all Column Encryption Keys needed for a result set.
157#[derive(Debug, Clone, Default)]
158pub struct CekTable {
159    /// CEK entries.
160    pub entries: Vec<CekTableEntry>,
161}
162
163impl CekTable {
164    /// Create an empty CEK table.
165    #[must_use]
166    pub fn new() -> Self {
167        Self::default()
168    }
169
170    /// Get a CEK entry by ordinal.
171    #[must_use]
172    pub fn get(&self, ordinal: u16) -> Option<&CekTableEntry> {
173        self.entries.get(ordinal as usize)
174    }
175
176    /// Check if the table is empty.
177    #[must_use]
178    pub fn is_empty(&self) -> bool {
179        self.entries.is_empty()
180    }
181
182    /// Get the number of entries.
183    #[must_use]
184    pub fn len(&self) -> usize {
185        self.entries.len()
186    }
187
188    /// Decode a CEK table from the wire format.
189    ///
190    /// # Wire Format
191    ///
192    /// ```text
193    /// CEK_TABLE:
194    ///   cek_count: USHORT (2 bytes)
195    ///   entries: CEK_ENTRY[cek_count]
196    ///
197    /// CEK_ENTRY:
198    ///   database_id: DWORD (4 bytes)
199    ///   cek_id: DWORD (4 bytes)
200    ///   cek_version: DWORD (4 bytes)
201    ///   cek_md_version: ULONGLONG (8 bytes)
202    ///   value_count: BYTE (1 byte)
203    ///   values: CEK_VALUE[value_count]
204    ///
205    /// CEK_VALUE:
206    ///   encrypted_value_length: USHORT (2 bytes)
207    ///   encrypted_value: BYTE[encrypted_value_length]
208    ///   key_store_name: B_VARCHAR
209    ///   cmk_path: US_VARCHAR
210    ///   algorithm: B_VARCHAR
211    /// ```
212    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
213        if src.remaining() < 2 {
214            return Err(ProtocolError::UnexpectedEof);
215        }
216
217        let cek_count = src.get_u16_le() as usize;
218
219        let mut entries = Vec::with_capacity(cek_count);
220
221        for _ in 0..cek_count {
222            let entry = CekTableEntry::decode(src)?;
223            entries.push(entry);
224        }
225
226        Ok(Self { entries })
227    }
228}
229
230impl CekTableEntry {
231    /// Decode a CEK table entry from the wire format.
232    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
233        // database_id (4) + cek_id (4) + cek_version (4) + cek_md_version (8) + value_count (1)
234        if src.remaining() < 21 {
235            return Err(ProtocolError::UnexpectedEof);
236        }
237
238        let database_id = src.get_u32_le();
239        let cek_id = src.get_u32_le();
240        let cek_version = src.get_u32_le();
241        let cek_md_version = src.get_u64_le();
242        let value_count = src.get_u8() as usize;
243
244        let mut values = Vec::with_capacity(value_count);
245
246        for _ in 0..value_count {
247            let value = CekValue::decode(src)?;
248            values.push(value);
249        }
250
251        Ok(Self {
252            database_id,
253            cek_id,
254            cek_version,
255            cek_md_version,
256            values,
257        })
258    }
259
260    /// Get the first (primary) encrypted value.
261    #[must_use]
262    pub fn primary_value(&self) -> Option<&CekValue> {
263        self.values.first()
264    }
265}
266
267impl CekValue {
268    /// Decode a CEK value from the wire format.
269    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
270        // encrypted_value_length (2 bytes)
271        if src.remaining() < 2 {
272            return Err(ProtocolError::UnexpectedEof);
273        }
274
275        let encrypted_value_length = src.get_u16_le() as usize;
276
277        if src.remaining() < encrypted_value_length {
278            return Err(ProtocolError::UnexpectedEof);
279        }
280
281        let encrypted_value = src.copy_to_bytes(encrypted_value_length);
282
283        // key_store_name (B_VARCHAR)
284        let key_store_provider_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
285
286        // cmk_path (US_VARCHAR)
287        let cmk_path = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
288
289        // algorithm (B_VARCHAR)
290        let encryption_algorithm = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
291
292        Ok(Self {
293            encrypted_value,
294            key_store_provider_name,
295            cmk_path,
296            encryption_algorithm,
297        })
298    }
299}
300
301impl CryptoMetadata {
302    /// Size of crypto metadata in bytes.
303    pub const SIZE: usize = 5; // ordinal (2) + algorithm (1) + enc_type (1) + norm_version (1)
304
305    /// Decode crypto metadata from the wire format.
306    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
307        if src.remaining() < Self::SIZE {
308            return Err(ProtocolError::UnexpectedEof);
309        }
310
311        let cek_table_ordinal = src.get_u16_le();
312        let algorithm_id = src.get_u8();
313        let encryption_type_byte = src.get_u8();
314        let normalization_version = src.get_u8();
315
316        let encryption_type = EncryptionTypeWire::from_u8(encryption_type_byte).ok_or(
317            ProtocolError::InvalidField {
318                field: "encryption_type",
319                value: encryption_type_byte as u32,
320            },
321        )?;
322
323        Ok(Self {
324            cek_table_ordinal,
325            algorithm_id,
326            encryption_type,
327            normalization_version,
328        })
329    }
330
331    /// Check if this uses the standard AEAD algorithm.
332    #[must_use]
333    pub fn is_aead_aes_256(&self) -> bool {
334        self.algorithm_id == ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
335    }
336
337    /// Check if this uses deterministic encryption.
338    #[must_use]
339    pub fn is_deterministic(&self) -> bool {
340        self.encryption_type == EncryptionTypeWire::Deterministic
341    }
342
343    /// Check if this uses randomized encryption.
344    #[must_use]
345    pub fn is_randomized(&self) -> bool {
346        self.encryption_type == EncryptionTypeWire::Randomized
347    }
348}
349
350/// Extended column metadata with encryption information.
351///
352/// This combines the base column metadata with optional crypto metadata
353/// for Always Encrypted columns.
354#[derive(Debug, Clone, Default)]
355pub struct ColumnCryptoInfo {
356    /// Crypto metadata (if column is encrypted).
357    pub crypto_metadata: Option<CryptoMetadata>,
358}
359
360impl ColumnCryptoInfo {
361    /// Create info for an unencrypted column.
362    #[must_use]
363    pub fn unencrypted() -> Self {
364        Self {
365            crypto_metadata: None,
366        }
367    }
368
369    /// Create info for an encrypted column.
370    #[must_use]
371    pub fn encrypted(metadata: CryptoMetadata) -> Self {
372        Self {
373            crypto_metadata: Some(metadata),
374        }
375    }
376
377    /// Check if this column is encrypted.
378    #[must_use]
379    pub fn is_encrypted(&self) -> bool {
380        self.crypto_metadata.is_some()
381    }
382}
383
384/// Check if a column flags value indicates encryption.
385#[must_use]
386pub fn is_column_encrypted(flags: u16) -> bool {
387    (flags & COLUMN_FLAG_ENCRYPTED) != 0
388}
389
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_encryption_type_wire_roundtrip() {
        assert_eq!(
            EncryptionTypeWire::from_u8(1),
            Some(EncryptionTypeWire::Deterministic)
        );
        assert_eq!(
            EncryptionTypeWire::from_u8(2),
            Some(EncryptionTypeWire::Randomized)
        );
        assert_eq!(EncryptionTypeWire::from_u8(0), None);
        assert_eq!(EncryptionTypeWire::from_u8(99), None);

        assert_eq!(EncryptionTypeWire::Deterministic.to_u8(), 1);
        assert_eq!(EncryptionTypeWire::Randomized.to_u8(), 2);
    }

    #[test]
    fn test_crypto_metadata_decode() {
        let data = [
            0x00, 0x00, // cek_table_ordinal = 0
            0x02, // algorithm_id = AEAD_AES_256_CBC_HMAC_SHA256
            0x01, // encryption_type = Deterministic
            0x01, // normalization_version = 1
        ];

        let mut cursor: &[u8] = &data;
        let metadata = CryptoMetadata::decode(&mut cursor).unwrap();

        assert_eq!(metadata.cek_table_ordinal, 0);
        assert_eq!(
            metadata.algorithm_id,
            ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
        );
        assert_eq!(metadata.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(metadata.normalization_version, 1);
        assert!(metadata.is_aead_aes_256());
        assert!(metadata.is_deterministic());
        assert!(!metadata.is_randomized());
    }

    #[test]
    fn test_crypto_metadata_decode_randomized() {
        let data = [
            0x03, 0x00, // cek_table_ordinal = 3
            0x02, // algorithm_id = AEAD_AES_256_CBC_HMAC_SHA256
            0x02, // encryption_type = Randomized
            0x01, // normalization_version = 1
        ];

        let mut cursor: &[u8] = &data;
        let metadata = CryptoMetadata::decode(&mut cursor).unwrap();

        assert_eq!(metadata.cek_table_ordinal, 3);
        assert!(metadata.is_randomized());
        assert!(!metadata.is_deterministic());
        // Exactly CryptoMetadata::SIZE bytes are consumed.
        assert!(cursor.is_empty());
    }

    #[test]
    fn test_crypto_metadata_decode_truncated() {
        // One byte short of CryptoMetadata::SIZE.
        let data = [0x00, 0x00, 0x02, 0x01];
        let mut cursor: &[u8] = &data;

        let err = CryptoMetadata::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::UnexpectedEof));
    }

    #[test]
    fn test_crypto_metadata_decode_unknown_encryption_type() {
        let data = [
            0x00, 0x00, // cek_table_ordinal = 0
            0x02, // algorithm_id
            0x03, // encryption_type = 3 (not defined)
            0x01, // normalization_version
        ];
        let mut cursor: &[u8] = &data;

        let err = CryptoMetadata::decode(&mut cursor).unwrap_err();
        assert!(matches!(
            err,
            ProtocolError::InvalidField { field, value: 3 } if field == "encryption_type"
        ));
    }

    #[test]
    fn test_cek_value_decode() {
        let mut data = BytesMut::new();

        // encrypted_value_length = 4
        data.extend_from_slice(&[0x04, 0x00]);
        // encrypted_value = [0xDE, 0xAD, 0xBE, 0xEF]
        data.extend_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        // key_store_name = "TEST" (B_VARCHAR: 1 byte len + utf16le)
        data.extend_from_slice(&[0x04]); // 4 chars
        data.extend_from_slice(&[b'T', 0x00, b'E', 0x00, b'S', 0x00, b'T', 0x00]);
        // cmk_path = "key1" (US_VARCHAR: 2 byte len + utf16le)
        data.extend_from_slice(&[0x04, 0x00]); // 4 chars
        data.extend_from_slice(&[b'k', 0x00, b'e', 0x00, b'y', 0x00, b'1', 0x00]);
        // algorithm = "RSA" (B_VARCHAR)
        data.extend_from_slice(&[0x03]); // 3 chars
        data.extend_from_slice(&[b'R', 0x00, b'S', 0x00, b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let value = CekValue::decode(&mut cursor).unwrap();

        assert_eq!(value.encrypted_value.as_ref(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(value.key_store_provider_name, "TEST");
        assert_eq!(value.cmk_path, "key1");
        assert_eq!(value.encryption_algorithm, "RSA");
    }

    #[test]
    fn test_cek_value_decode_truncated_encrypted_value() {
        // Declares 4 bytes of ciphertext but only 2 follow.
        let data = [0x04, 0x00, 0xDE, 0xAD];
        let mut cursor: &[u8] = &data;

        let err = CekValue::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::UnexpectedEof));
    }

    #[test]
    fn test_cek_table_entry_decode() {
        let mut data = BytesMut::new();

        // database_id = 1
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        // cek_id = 2
        data.extend_from_slice(&[0x02, 0x00, 0x00, 0x00]);
        // cek_version = 1
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]);
        // cek_md_version = 100
        data.extend_from_slice(&[0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]);
        // value_count = 1
        data.extend_from_slice(&[0x01]);

        // CEK value
        data.extend_from_slice(&[0x04, 0x00]); // encrypted_value_length = 4
        data.extend_from_slice(&[0x11, 0x22, 0x33, 0x44]); // encrypted_value
        data.extend_from_slice(&[0x02]); // key_store_name length = 2
        data.extend_from_slice(&[b'K', 0x00, b'S', 0x00]); // "KS"
        data.extend_from_slice(&[0x01, 0x00]); // cmk_path length = 1
        data.extend_from_slice(&[b'P', 0x00]); // "P"
        data.extend_from_slice(&[0x01]); // algorithm length = 1
        data.extend_from_slice(&[b'A', 0x00]); // "A"

        let mut cursor: &[u8] = &data;
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert_eq!(entry.database_id, 1);
        assert_eq!(entry.cek_id, 2);
        assert_eq!(entry.cek_version, 1);
        assert_eq!(entry.cek_md_version, 100);
        assert_eq!(entry.values.len(), 1);

        let value = entry.primary_value().expect("should have primary value");
        assert_eq!(value.encrypted_value.as_ref(), &[0x11, 0x22, 0x33, 0x44]);
    }

    #[test]
    fn test_cek_table_entry_decode_without_values() {
        let mut data = BytesMut::new();

        data.extend_from_slice(&[0x05, 0x00, 0x00, 0x00]); // database_id = 5
        data.extend_from_slice(&[0x06, 0x00, 0x00, 0x00]); // cek_id = 6
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]); // cek_version = 1
        data.extend_from_slice(&[0x00; 8]); // cek_md_version = 0
        data.extend_from_slice(&[0x00]); // value_count = 0

        let mut cursor: &[u8] = &data;
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert_eq!(entry.database_id, 5);
        assert_eq!(entry.cek_id, 6);
        assert!(entry.values.is_empty());
        assert!(entry.primary_value().is_none());
    }

    #[test]
    fn test_cek_table_entry_decode_truncated_header() {
        // The fixed header needs 21 bytes.
        let data = [0u8; 20];
        let mut cursor: &[u8] = &data;

        let err = CekTableEntry::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::UnexpectedEof));
    }

    #[test]
    fn test_cek_table_decode() {
        let mut data = BytesMut::new();

        // cek_count = 1
        data.extend_from_slice(&[0x01, 0x00]);

        // CEK entry
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]); // database_id
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]); // cek_id
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]); // cek_version
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); // cek_md_version
        data.extend_from_slice(&[0x01]); // value_count

        // CEK value
        data.extend_from_slice(&[0x02, 0x00]); // encrypted_value_length = 2
        data.extend_from_slice(&[0xAB, 0xCD]); // encrypted_value
        data.extend_from_slice(&[0x01]); // key_store_name = "K"
        data.extend_from_slice(&[b'K', 0x00]);
        data.extend_from_slice(&[0x01, 0x00]); // cmk_path = "P"
        data.extend_from_slice(&[b'P', 0x00]);
        data.extend_from_slice(&[0x01]); // algorithm = "A"
        data.extend_from_slice(&[b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let table = CekTable::decode(&mut cursor).expect("should decode table");

        assert_eq!(table.len(), 1);
        assert!(!table.is_empty());

        let entry = table.get(0).expect("should have first entry");
        assert_eq!(entry.database_id, 1);
        assert!(table.get(1).is_none());
    }

    #[test]
    fn test_cek_table_decode_empty() {
        let data = [0x00, 0x00]; // cek_count = 0
        let mut cursor: &[u8] = &data;

        let table = CekTable::decode(&mut cursor).expect("should decode empty table");

        assert!(table.is_empty());
        assert_eq!(table.len(), 0);
        assert!(table.get(0).is_none());
        assert!(CekTable::new().is_empty());
    }

    #[test]
    fn test_cek_table_decode_missing_count() {
        let data = [0x01]; // count needs 2 bytes
        let mut cursor: &[u8] = &data;

        let err = CekTable::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::UnexpectedEof));
    }

    #[test]
    fn test_cek_table_decode_count_exceeds_data() {
        // Claims 65535 entries but carries no entry bytes at all.
        let data = [0xFF, 0xFF];
        let mut cursor: &[u8] = &data;

        let err = CekTable::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::UnexpectedEof));
    }

    #[test]
    fn test_is_column_encrypted() {
        assert!(!is_column_encrypted(0x0000));
        assert!(!is_column_encrypted(0x0001)); // nullable
        assert!(is_column_encrypted(0x0800)); // encrypted flag
        assert!(is_column_encrypted(0x0801)); // encrypted + nullable
    }

    #[test]
    fn test_column_crypto_info() {
        let unencrypted = ColumnCryptoInfo::unencrypted();
        assert!(!unencrypted.is_encrypted());
        assert!(!ColumnCryptoInfo::default().is_encrypted());

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Randomized,
            normalization_version: 1,
        };
        let encrypted = ColumnCryptoInfo::encrypted(metadata);
        assert!(encrypted.is_encrypted());
    }
}