tds_protocol/
crypto.rs

1//! Always Encrypted cryptography metadata for TDS protocol.
2//!
3//! This module defines the wire-level structures for SQL Server's Always Encrypted
4//! feature. When a query returns encrypted columns, SQL Server sends additional
5//! metadata describing how to decrypt the data.
6//!
7//! ## TDS Wire Format
8//!
9//! When Always Encrypted is enabled, the COLMETADATA token includes:
10//!
11//! 1. **CEK Table**: A table of Column Encryption Keys needed for the result set
12//! 2. **CryptoMetadata**: Per-column encryption information
13//!
14//! ```text
15//! COLMETADATA Token (with encryption):
16//! ┌─────────────────────────────────────────────────────────────────┐
17//! │ Column Count (2 bytes)                                          │
18//! ├─────────────────────────────────────────────────────────────────┤
19//! │ CEK Table (if encrypted columns present)                        │
20//! │ ├── CEK Count (2 bytes)                                         │
21//! │ ├── CEK Entry 1                                                 │
22//! │ │   ├── Database ID (4 bytes)                                   │
23//! │ │   ├── CEK ID (4 bytes)                                        │
24//! │ │   ├── CEK Version (4 bytes)                                   │
25//! │ │   ├── CEK MD Version (8 bytes)                                │
26//! │ │   ├── CEK Value Count (1 byte)                                │
27//! │ │   └── CEK Value(s)                                            │
28//! │ │       ├── Encrypted Value Length (2 bytes)                    │
29//! │ │       ├── Encrypted Value (variable)                          │
30//! │ │       ├── Key Store Name (B_VARCHAR)                          │
31//! │ │       ├── CMK Path (US_VARCHAR)                               │
32//! │ │       └── Algorithm (B_VARCHAR)                               │
33//! │ └── ...more CEK entries                                         │
34//! ├─────────────────────────────────────────────────────────────────┤
35//! │ Column Definitions                                              │
36//! │ ├── Column 1                                                    │
37//! │ │   ├── User Type (4 bytes)                                     │
38//! │ │   ├── Flags (2 bytes) - includes encryption flag              │
39//! │ │   ├── Type ID (1 byte)                                        │
40//! │ │   ├── Type Info (variable)                                    │
41//! │ │   ├── CryptoMetadata (if encrypted)                           │
42//! │ │   │   ├── CEK Table Ordinal (2 bytes)                         │
43//! │ │   │   ├── Algorithm ID (1 byte)                               │
44//! │ │   │   ├── Encryption Type (1 byte)                            │
45//! │ │   │   └── Normalization Version (1 byte)                      │
46//! │ │   └── Column Name (B_VARCHAR)                                 │
47//! │ └── ...more columns                                             │
48//! └─────────────────────────────────────────────────────────────────┘
49//! ```
50
51use bytes::{Buf, Bytes};
52
53use crate::codec::{read_b_varchar, read_us_varchar};
54use crate::error::ProtocolError;
55use crate::prelude::*;
56
/// Bit mask (bit 11) within the column flags word that marks a column as encrypted.
pub const COLUMN_FLAG_ENCRYPTED: u16 = 1 << 11;

/// Wire identifier of the AEAD_AES_256_CBC_HMAC_SHA256 algorithm.
pub const ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256: u8 = 0x02;

/// Wire value of the deterministic encryption type.
pub const ENCRYPTION_TYPE_DETERMINISTIC: u8 = 0x01;

/// Wire value of the randomized encryption type.
pub const ENCRYPTION_TYPE_RANDOMIZED: u8 = 0x02;

/// Normalization rule version currently in use.
pub const NORMALIZATION_RULE_VERSION: u8 = 0x01;
71
72/// Column Encryption Key table entry.
73///
74/// This represents a single CEK entry in the CEK table sent with COLMETADATA.
75/// Multiple columns may share the same CEK.
76#[derive(Debug, Clone)]
77pub struct CekTableEntry {
78    /// Database ID where the CEK is defined.
79    pub database_id: u32,
80    /// CEK ID within the database.
81    pub cek_id: u32,
82    /// CEK version (incremented on key rotation).
83    pub cek_version: u32,
84    /// Metadata version (changes with any metadata update).
85    pub cek_md_version: u64,
86    /// CEK value entries (usually one, but may have multiple for key rotation).
87    pub values: Vec<CekValue>,
88}
89
90/// A single CEK value (encrypted by CMK).
91///
92/// A CEK may have multiple values when key rotation is in progress,
93/// with different CMKs encrypting the same CEK.
94#[derive(Debug, Clone)]
95pub struct CekValue {
96    /// The encrypted CEK bytes.
97    pub encrypted_value: Bytes,
98    /// Name of the key store provider (e.g., "AZURE_KEY_VAULT").
99    pub key_store_provider_name: String,
100    /// Path to the Column Master Key in the key store.
101    pub cmk_path: String,
102    /// Asymmetric algorithm used to encrypt the CEK (e.g., "RSA_OAEP").
103    pub encryption_algorithm: String,
104}
105
106/// Per-column encryption metadata.
107///
108/// This metadata is present for each encrypted column and describes
109/// how to decrypt the column data.
110#[derive(Debug, Clone)]
111pub struct CryptoMetadata {
112    /// Index into the CEK table (0-based).
113    pub cek_table_ordinal: u16,
114    /// Encryption algorithm ID.
115    pub algorithm_id: u8,
116    /// Encryption type (deterministic or randomized).
117    pub encryption_type: EncryptionTypeWire,
118    /// Normalization rule version.
119    pub normalization_version: u8,
120}
121
122/// Wire-level encryption type.
123#[derive(Debug, Clone, Copy, PartialEq, Eq)]
124pub enum EncryptionTypeWire {
125    /// Deterministic encryption (value 1).
126    Deterministic,
127    /// Randomized encryption (value 2).
128    Randomized,
129}
130
131impl EncryptionTypeWire {
132    /// Create from wire value.
133    #[must_use]
134    pub fn from_u8(value: u8) -> Option<Self> {
135        match value {
136            ENCRYPTION_TYPE_DETERMINISTIC => Some(Self::Deterministic),
137            ENCRYPTION_TYPE_RANDOMIZED => Some(Self::Randomized),
138            _ => None,
139        }
140    }
141
142    /// Convert to wire value.
143    #[must_use]
144    pub fn to_u8(self) -> u8 {
145        match self {
146            Self::Deterministic => ENCRYPTION_TYPE_DETERMINISTIC,
147            Self::Randomized => ENCRYPTION_TYPE_RANDOMIZED,
148        }
149    }
150}
151
152/// CEK table containing all Column Encryption Keys needed for a result set.
153#[derive(Debug, Clone, Default)]
154pub struct CekTable {
155    /// CEK entries.
156    pub entries: Vec<CekTableEntry>,
157}
158
159impl CekTable {
160    /// Create an empty CEK table.
161    #[must_use]
162    pub fn new() -> Self {
163        Self::default()
164    }
165
166    /// Get a CEK entry by ordinal.
167    #[must_use]
168    pub fn get(&self, ordinal: u16) -> Option<&CekTableEntry> {
169        self.entries.get(ordinal as usize)
170    }
171
172    /// Check if the table is empty.
173    #[must_use]
174    pub fn is_empty(&self) -> bool {
175        self.entries.is_empty()
176    }
177
178    /// Get the number of entries.
179    #[must_use]
180    pub fn len(&self) -> usize {
181        self.entries.len()
182    }
183
184    /// Decode a CEK table from the wire format.
185    ///
186    /// # Wire Format
187    ///
188    /// ```text
189    /// CEK_TABLE:
190    ///   cek_count: USHORT (2 bytes)
191    ///   entries: CEK_ENTRY[cek_count]
192    ///
193    /// CEK_ENTRY:
194    ///   database_id: DWORD (4 bytes)
195    ///   cek_id: DWORD (4 bytes)
196    ///   cek_version: DWORD (4 bytes)
197    ///   cek_md_version: ULONGLONG (8 bytes)
198    ///   value_count: BYTE (1 byte)
199    ///   values: CEK_VALUE[value_count]
200    ///
201    /// CEK_VALUE:
202    ///   encrypted_value_length: USHORT (2 bytes)
203    ///   encrypted_value: BYTE[encrypted_value_length]
204    ///   key_store_name: B_VARCHAR
205    ///   cmk_path: US_VARCHAR
206    ///   algorithm: B_VARCHAR
207    /// ```
208    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
209        if src.remaining() < 2 {
210            return Err(ProtocolError::UnexpectedEof);
211        }
212
213        let cek_count = src.get_u16_le() as usize;
214
215        let mut entries = Vec::with_capacity(cek_count);
216
217        for _ in 0..cek_count {
218            let entry = CekTableEntry::decode(src)?;
219            entries.push(entry);
220        }
221
222        Ok(Self { entries })
223    }
224}
225
226impl CekTableEntry {
227    /// Decode a CEK table entry from the wire format.
228    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
229        // database_id (4) + cek_id (4) + cek_version (4) + cek_md_version (8) + value_count (1)
230        if src.remaining() < 21 {
231            return Err(ProtocolError::UnexpectedEof);
232        }
233
234        let database_id = src.get_u32_le();
235        let cek_id = src.get_u32_le();
236        let cek_version = src.get_u32_le();
237        let cek_md_version = src.get_u64_le();
238        let value_count = src.get_u8() as usize;
239
240        let mut values = Vec::with_capacity(value_count);
241
242        for _ in 0..value_count {
243            let value = CekValue::decode(src)?;
244            values.push(value);
245        }
246
247        Ok(Self {
248            database_id,
249            cek_id,
250            cek_version,
251            cek_md_version,
252            values,
253        })
254    }
255
256    /// Get the first (primary) encrypted value.
257    #[must_use]
258    pub fn primary_value(&self) -> Option<&CekValue> {
259        self.values.first()
260    }
261}
262
263impl CekValue {
264    /// Decode a CEK value from the wire format.
265    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
266        // encrypted_value_length (2 bytes)
267        if src.remaining() < 2 {
268            return Err(ProtocolError::UnexpectedEof);
269        }
270
271        let encrypted_value_length = src.get_u16_le() as usize;
272
273        if src.remaining() < encrypted_value_length {
274            return Err(ProtocolError::UnexpectedEof);
275        }
276
277        let encrypted_value = src.copy_to_bytes(encrypted_value_length);
278
279        // key_store_name (B_VARCHAR)
280        let key_store_provider_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
281
282        // cmk_path (US_VARCHAR)
283        let cmk_path = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
284
285        // algorithm (B_VARCHAR)
286        let encryption_algorithm = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
287
288        Ok(Self {
289            encrypted_value,
290            key_store_provider_name,
291            cmk_path,
292            encryption_algorithm,
293        })
294    }
295}
296
297impl CryptoMetadata {
298    /// Size of crypto metadata in bytes.
299    pub const SIZE: usize = 5; // ordinal (2) + algorithm (1) + enc_type (1) + norm_version (1)
300
301    /// Decode crypto metadata from the wire format.
302    pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
303        if src.remaining() < Self::SIZE {
304            return Err(ProtocolError::UnexpectedEof);
305        }
306
307        let cek_table_ordinal = src.get_u16_le();
308        let algorithm_id = src.get_u8();
309        let encryption_type_byte = src.get_u8();
310        let normalization_version = src.get_u8();
311
312        let encryption_type = EncryptionTypeWire::from_u8(encryption_type_byte).ok_or(
313            ProtocolError::InvalidField {
314                field: "encryption_type",
315                value: encryption_type_byte as u32,
316            },
317        )?;
318
319        Ok(Self {
320            cek_table_ordinal,
321            algorithm_id,
322            encryption_type,
323            normalization_version,
324        })
325    }
326
327    /// Check if this uses the standard AEAD algorithm.
328    #[must_use]
329    pub fn is_aead_aes_256(&self) -> bool {
330        self.algorithm_id == ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
331    }
332
333    /// Check if this uses deterministic encryption.
334    #[must_use]
335    pub fn is_deterministic(&self) -> bool {
336        self.encryption_type == EncryptionTypeWire::Deterministic
337    }
338
339    /// Check if this uses randomized encryption.
340    #[must_use]
341    pub fn is_randomized(&self) -> bool {
342        self.encryption_type == EncryptionTypeWire::Randomized
343    }
344}
345
346/// Extended column metadata with encryption information.
347///
348/// This combines the base column metadata with optional crypto metadata
349/// for Always Encrypted columns.
350#[derive(Debug, Clone, Default)]
351pub struct ColumnCryptoInfo {
352    /// Crypto metadata (if column is encrypted).
353    pub crypto_metadata: Option<CryptoMetadata>,
354}
355
356impl ColumnCryptoInfo {
357    /// Create info for an unencrypted column.
358    #[must_use]
359    pub fn unencrypted() -> Self {
360        Self {
361            crypto_metadata: None,
362        }
363    }
364
365    /// Create info for an encrypted column.
366    #[must_use]
367    pub fn encrypted(metadata: CryptoMetadata) -> Self {
368        Self {
369            crypto_metadata: Some(metadata),
370        }
371    }
372
373    /// Check if this column is encrypted.
374    #[must_use]
375    pub fn is_encrypted(&self) -> bool {
376        self.crypto_metadata.is_some()
377    }
378}
379
380/// Check if a column flags value indicates encryption.
381#[must_use]
382pub fn is_column_encrypted(flags: u16) -> bool {
383    (flags & COLUMN_FLAG_ENCRYPTED) != 0
384}
385
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Append `text` as UTF-16LE code units, without a length prefix.
    fn put_utf16le(buf: &mut BytesMut, text: &str) {
        for unit in text.encode_utf16() {
            buf.extend_from_slice(&unit.to_le_bytes());
        }
    }

    /// Append a B_VARCHAR: 1-byte length in UTF-16 code units, then the text.
    fn put_b_varchar(buf: &mut BytesMut, text: &str) {
        let units = u8::try_from(text.encode_utf16().count()).expect("text fits in B_VARCHAR");
        buf.extend_from_slice(&[units]);
        put_utf16le(buf, text);
    }

    /// Append a US_VARCHAR: 2-byte LE length in UTF-16 code units, then the text.
    fn put_us_varchar(buf: &mut BytesMut, text: &str) {
        let units = u16::try_from(text.encode_utf16().count()).expect("text fits in US_VARCHAR");
        buf.extend_from_slice(&units.to_le_bytes());
        put_utf16le(buf, text);
    }

    /// Append one encoded CEK_VALUE.
    fn put_cek_value(
        buf: &mut BytesMut,
        encrypted: &[u8],
        key_store: &str,
        cmk_path: &str,
        algorithm: &str,
    ) {
        let length = u16::try_from(encrypted.len()).expect("encrypted value fits in USHORT");
        buf.extend_from_slice(&length.to_le_bytes());
        buf.extend_from_slice(encrypted);
        put_b_varchar(buf, key_store);
        put_us_varchar(buf, cmk_path);
        put_b_varchar(buf, algorithm);
    }

    /// Append the fixed-size part of a CEK_ENTRY (everything before the values).
    fn put_cek_entry_header(
        buf: &mut BytesMut,
        database_id: u32,
        cek_id: u32,
        cek_version: u32,
        cek_md_version: u64,
        value_count: u8,
    ) {
        buf.extend_from_slice(&database_id.to_le_bytes());
        buf.extend_from_slice(&cek_id.to_le_bytes());
        buf.extend_from_slice(&cek_version.to_le_bytes());
        buf.extend_from_slice(&cek_md_version.to_le_bytes());
        buf.extend_from_slice(&[value_count]);
    }

    #[test]
    fn test_encryption_type_wire_roundtrip() {
        assert_eq!(
            EncryptionTypeWire::from_u8(1),
            Some(EncryptionTypeWire::Deterministic)
        );
        assert_eq!(
            EncryptionTypeWire::from_u8(2),
            Some(EncryptionTypeWire::Randomized)
        );
        assert_eq!(EncryptionTypeWire::from_u8(0), None);
        assert_eq!(EncryptionTypeWire::from_u8(99), None);

        assert_eq!(EncryptionTypeWire::Deterministic.to_u8(), 1);
        assert_eq!(EncryptionTypeWire::Randomized.to_u8(), 2);
    }

    #[test]
    fn test_crypto_metadata_decode() {
        let data = [
            0x00, 0x00, // cek_table_ordinal = 0
            0x02, // algorithm_id = AEAD_AES_256_CBC_HMAC_SHA256
            0x01, // encryption_type = Deterministic
            0x01, // normalization_version = 1
        ];

        let mut cursor: &[u8] = &data;
        let metadata = CryptoMetadata::decode(&mut cursor).unwrap();

        assert_eq!(metadata.cek_table_ordinal, 0);
        assert_eq!(
            metadata.algorithm_id,
            ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
        );
        assert_eq!(metadata.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(metadata.normalization_version, 1);
        assert!(metadata.is_aead_aes_256());
        assert!(metadata.is_deterministic());
        assert!(!metadata.is_randomized());
        // Exactly CryptoMetadata::SIZE bytes must have been consumed.
        assert!(cursor.is_empty());
    }

    #[test]
    fn test_crypto_metadata_decode_randomized() {
        let data = [
            0x03, 0x00, // cek_table_ordinal = 3
            0x02, // algorithm_id
            0x02, // encryption_type = Randomized
            0x01, // normalization_version
        ];

        let mut cursor: &[u8] = &data;
        let metadata = CryptoMetadata::decode(&mut cursor).unwrap();

        assert_eq!(metadata.cek_table_ordinal, 3);
        assert!(metadata.is_randomized());
        assert!(!metadata.is_deterministic());
    }

    #[test]
    fn test_crypto_metadata_decode_truncated() {
        // One byte short of CryptoMetadata::SIZE.
        let data = [0x00, 0x00, 0x02, 0x01];

        let mut cursor: &[u8] = &data;
        let result = CryptoMetadata::decode(&mut cursor);

        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[test]
    fn test_crypto_metadata_decode_unknown_encryption_type() {
        let data = [
            0x00, 0x00, // cek_table_ordinal
            0x02, // algorithm_id
            0x07, // encryption_type = unknown
            0x01, // normalization_version
        ];

        let mut cursor: &[u8] = &data;
        let result = CryptoMetadata::decode(&mut cursor);

        assert!(matches!(
            result,
            Err(ProtocolError::InvalidField {
                field: "encryption_type",
                value: 7,
            })
        ));
    }

    #[test]
    fn test_cek_value_decode() {
        // Spelled out byte by byte to document the wire layout.
        let mut data = BytesMut::new();

        // encrypted_value_length = 4
        data.extend_from_slice(&[0x04, 0x00]);
        // encrypted_value = [0xDE, 0xAD, 0xBE, 0xEF]
        data.extend_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        // key_store_name = "TEST" (B_VARCHAR: 1 byte len + utf16le)
        data.extend_from_slice(&[0x04]); // 4 chars
        data.extend_from_slice(&[b'T', 0x00, b'E', 0x00, b'S', 0x00, b'T', 0x00]);
        // cmk_path = "key1" (US_VARCHAR: 2 byte len + utf16le)
        data.extend_from_slice(&[0x04, 0x00]); // 4 chars
        data.extend_from_slice(&[b'k', 0x00, b'e', 0x00, b'y', 0x00, b'1', 0x00]);
        // algorithm = "RSA" (B_VARCHAR)
        data.extend_from_slice(&[0x03]); // 3 chars
        data.extend_from_slice(&[b'R', 0x00, b'S', 0x00, b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let value = CekValue::decode(&mut cursor).unwrap();

        assert_eq!(value.encrypted_value.as_ref(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(value.key_store_provider_name, "TEST");
        assert_eq!(value.cmk_path, "key1");
        assert_eq!(value.encryption_algorithm, "RSA");
    }

    #[test]
    fn test_cek_value_decode_truncated_encrypted_value() {
        // Declares 4 encrypted bytes but only provides 2.
        let data = [0x04, 0x00, 0xDE, 0xAD];

        let mut cursor: &[u8] = &data;
        let result = CekValue::decode(&mut cursor);

        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[test]
    fn test_cek_value_decode_missing_length_prefix() {
        let data = [0x04];

        let mut cursor: &[u8] = &data;
        let result = CekValue::decode(&mut cursor);

        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[test]
    fn test_cek_table_entry_decode() {
        let mut data = BytesMut::new();
        put_cek_entry_header(&mut data, 1, 2, 1, 100, 1);
        put_cek_value(&mut data, &[0x11, 0x22, 0x33, 0x44], "KS", "P", "A");

        let mut cursor: &[u8] = &data;
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert_eq!(entry.database_id, 1);
        assert_eq!(entry.cek_id, 2);
        assert_eq!(entry.cek_version, 1);
        assert_eq!(entry.cek_md_version, 100);
        assert_eq!(entry.values.len(), 1);

        let value = entry.primary_value().expect("should have primary value");
        assert_eq!(value.encrypted_value.as_ref(), &[0x11, 0x22, 0x33, 0x44]);
    }

    #[test]
    fn test_cek_table_entry_decode_multiple_values() {
        let mut data = BytesMut::new();
        put_cek_entry_header(&mut data, 5, 6, 7, 8, 2);
        put_cek_value(&mut data, &[0x01], "KS1", "path1", "A1");
        put_cek_value(&mut data, &[0x02, 0x03], "KS2", "path2", "A2");

        let mut cursor: &[u8] = &data;
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert_eq!(entry.values.len(), 2);
        assert_eq!(entry.values[1].encrypted_value.as_ref(), &[0x02, 0x03]);
        assert_eq!(entry.values[1].key_store_provider_name, "KS2");
        assert_eq!(entry.values[1].cmk_path, "path2");
        assert_eq!(entry.values[1].encryption_algorithm, "A2");

        // The primary value is the first one on the wire.
        let primary = entry.primary_value().expect("should have primary value");
        assert_eq!(primary.key_store_provider_name, "KS1");
    }

    #[test]
    fn test_cek_table_entry_decode_without_values() {
        let mut data = BytesMut::new();
        put_cek_entry_header(&mut data, 1, 1, 1, 1, 0);

        let mut cursor: &[u8] = &data;
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert!(entry.values.is_empty());
        assert!(entry.primary_value().is_none());
    }

    #[test]
    fn test_cek_table_entry_decode_truncated_header() {
        // One byte short of the 21-byte fixed header.
        let data = [0u8; 20];

        let mut cursor: &[u8] = &data;
        let result = CekTableEntry::decode(&mut cursor);

        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[test]
    fn test_cek_table_decode() {
        let mut data = BytesMut::new();

        // cek_count = 1
        data.extend_from_slice(&[0x01, 0x00]);
        put_cek_entry_header(&mut data, 1, 1, 1, 1, 1);
        put_cek_value(&mut data, &[0xAB, 0xCD], "K", "P", "A");

        let mut cursor: &[u8] = &data;
        let table = CekTable::decode(&mut cursor).expect("should decode table");

        assert_eq!(table.len(), 1);
        assert!(!table.is_empty());

        let entry = table.get(0).expect("should have first entry");
        assert_eq!(entry.database_id, 1);
        assert!(table.get(1).is_none());
    }

    #[test]
    fn test_cek_table_decode_zero_entries() {
        let data = [0x00, 0x00];

        let mut cursor: &[u8] = &data;
        let table = CekTable::decode(&mut cursor).expect("should decode empty table");

        assert!(table.is_empty());
        assert_eq!(table.len(), 0);
        assert!(table.get(0).is_none());
    }

    #[test]
    fn test_cek_table_decode_empty_input() {
        let mut cursor: &[u8] = &[];
        let result = CekTable::decode(&mut cursor);

        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[test]
    fn test_cek_table_decode_count_exceeds_payload() {
        // Claims the maximum number of entries but carries none of them.
        let data = [0xFF, 0xFF];

        let mut cursor: &[u8] = &data;
        let result = CekTable::decode(&mut cursor);

        assert!(matches!(result, Err(ProtocolError::UnexpectedEof)));
    }

    #[test]
    fn test_new_cek_table_is_empty() {
        let table = CekTable::new();

        assert!(table.is_empty());
        assert_eq!(table.len(), 0);
        assert!(table.get(0).is_none());
    }

    #[test]
    fn test_is_column_encrypted() {
        assert!(!is_column_encrypted(0x0000));
        assert!(!is_column_encrypted(0x0001)); // nullable
        assert!(is_column_encrypted(0x0800)); // encrypted flag
        assert!(is_column_encrypted(0x0801)); // encrypted + nullable
    }

    #[test]
    fn test_column_crypto_info() {
        let unencrypted = ColumnCryptoInfo::unencrypted();
        assert!(!unencrypted.is_encrypted());

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Randomized,
            normalization_version: 1,
        };
        let encrypted = ColumnCryptoInfo::encrypted(metadata);
        assert!(encrypted.is_encrypted());
    }
}