1use bytes::{Buf, Bytes};
52
53use crate::codec::{read_b_varchar, read_us_varchar};
54use crate::error::ProtocolError;
55use crate::prelude::*;
56
/// Bit in a column's flags word that marks the column as encrypted
/// (checked by `is_column_encrypted`). Equal to `0x0800`.
pub const COLUMN_FLAG_ENCRYPTED: u16 = 1 << 11;

/// `CryptoMetadata::algorithm_id` value recognised by
/// `CryptoMetadata::is_aead_aes_256` (AEAD_AES_256_CBC_HMAC_SHA256).
pub const ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256: u8 = 2;

/// Wire byte for deterministic encryption (`EncryptionTypeWire::Deterministic`).
pub const ENCRYPTION_TYPE_DETERMINISTIC: u8 = 1;

/// Wire byte for randomized encryption (`EncryptionTypeWire::Randomized`).
pub const ENCRYPTION_TYPE_RANDOMIZED: u8 = 2;

/// Normalization rule version value.
///
/// NOTE(review): not referenced by the visible decoders; presumably the version
/// callers expect in `CryptoMetadata::normalization_version` — confirm.
pub const NORMALIZATION_RULE_VERSION: u8 = 1;
71
72#[derive(Debug, Clone)]
77pub struct CekTableEntry {
78 pub database_id: u32,
80 pub cek_id: u32,
82 pub cek_version: u32,
84 pub cek_md_version: u64,
86 pub values: Vec<CekValue>,
88}
89
90#[derive(Debug, Clone)]
95pub struct CekValue {
96 pub encrypted_value: Bytes,
98 pub key_store_provider_name: String,
100 pub cmk_path: String,
102 pub encryption_algorithm: String,
104}
105
106#[derive(Debug, Clone)]
111pub struct CryptoMetadata {
112 pub cek_table_ordinal: u16,
114 pub algorithm_id: u8,
116 pub encryption_type: EncryptionTypeWire,
118 pub normalization_version: u8,
120}
121
/// Encryption type as carried in crypto metadata. Converted from and to its
/// wire byte with [`EncryptionTypeWire::from_u8`] / [`EncryptionTypeWire::to_u8`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum EncryptionTypeWire {
    /// Wire value `ENCRYPTION_TYPE_DETERMINISTIC` (1).
    Deterministic,
    /// Wire value `ENCRYPTION_TYPE_RANDOMIZED` (2).
    Randomized,
}
131
132impl EncryptionTypeWire {
133 #[must_use]
135 pub fn from_u8(value: u8) -> Option<Self> {
136 match value {
137 ENCRYPTION_TYPE_DETERMINISTIC => Some(Self::Deterministic),
138 ENCRYPTION_TYPE_RANDOMIZED => Some(Self::Randomized),
139 _ => None,
140 }
141 }
142
143 #[must_use]
145 pub fn to_u8(self) -> u8 {
146 match self {
147 Self::Deterministic => ENCRYPTION_TYPE_DETERMINISTIC,
148 Self::Randomized => ENCRYPTION_TYPE_RANDOMIZED,
149 }
150 }
151}
152
153#[derive(Debug, Clone, Default)]
155pub struct CekTable {
156 pub entries: Vec<CekTableEntry>,
158}
159
160impl CekTable {
161 #[must_use]
163 pub fn new() -> Self {
164 Self::default()
165 }
166
167 #[must_use]
169 pub fn get(&self, ordinal: u16) -> Option<&CekTableEntry> {
170 self.entries.get(ordinal as usize)
171 }
172
173 #[must_use]
175 pub fn is_empty(&self) -> bool {
176 self.entries.is_empty()
177 }
178
179 #[must_use]
181 pub fn len(&self) -> usize {
182 self.entries.len()
183 }
184
185 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
210 if src.remaining() < 2 {
211 return Err(ProtocolError::UnexpectedEof);
212 }
213
214 let cek_count = src.get_u16_le() as usize;
215
216 let mut entries = Vec::with_capacity(cek_count);
217
218 for _ in 0..cek_count {
219 let entry = CekTableEntry::decode(src)?;
220 entries.push(entry);
221 }
222
223 Ok(Self { entries })
224 }
225}
226
227impl CekTableEntry {
228 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
230 if src.remaining() < 21 {
232 return Err(ProtocolError::UnexpectedEof);
233 }
234
235 let database_id = src.get_u32_le();
236 let cek_id = src.get_u32_le();
237 let cek_version = src.get_u32_le();
238 let cek_md_version = src.get_u64_le();
239 let value_count = src.get_u8() as usize;
240
241 let mut values = Vec::with_capacity(value_count);
242
243 for _ in 0..value_count {
244 let value = CekValue::decode(src)?;
245 values.push(value);
246 }
247
248 Ok(Self {
249 database_id,
250 cek_id,
251 cek_version,
252 cek_md_version,
253 values,
254 })
255 }
256
257 #[must_use]
259 pub fn primary_value(&self) -> Option<&CekValue> {
260 self.values.first()
261 }
262}
263
264impl CekValue {
265 pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
267 if src.remaining() < 2 {
269 return Err(ProtocolError::UnexpectedEof);
270 }
271
272 let encrypted_value_length = src.get_u16_le() as usize;
273
274 if src.remaining() < encrypted_value_length {
275 return Err(ProtocolError::UnexpectedEof);
276 }
277
278 let encrypted_value = src.copy_to_bytes(encrypted_value_length);
279
280 let key_store_provider_name = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
282
283 let cmk_path = read_us_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
285
286 let encryption_algorithm = read_b_varchar(src).ok_or(ProtocolError::UnexpectedEof)?;
288
289 Ok(Self {
290 encrypted_value,
291 key_store_provider_name,
292 cmk_path,
293 encryption_algorithm,
294 })
295 }
296}
297
298impl CryptoMetadata {
299 pub const SIZE: usize = 5; pub fn decode(src: &mut impl Buf) -> Result<Self, ProtocolError> {
304 if src.remaining() < Self::SIZE {
305 return Err(ProtocolError::UnexpectedEof);
306 }
307
308 let cek_table_ordinal = src.get_u16_le();
309 let algorithm_id = src.get_u8();
310 let encryption_type_byte = src.get_u8();
311 let normalization_version = src.get_u8();
312
313 let encryption_type = EncryptionTypeWire::from_u8(encryption_type_byte).ok_or(
314 ProtocolError::InvalidField {
315 field: "encryption_type",
316 value: encryption_type_byte as u32,
317 },
318 )?;
319
320 Ok(Self {
321 cek_table_ordinal,
322 algorithm_id,
323 encryption_type,
324 normalization_version,
325 })
326 }
327
328 #[must_use]
330 pub fn is_aead_aes_256(&self) -> bool {
331 self.algorithm_id == ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
332 }
333
334 #[must_use]
336 pub fn is_deterministic(&self) -> bool {
337 self.encryption_type == EncryptionTypeWire::Deterministic
338 }
339
340 #[must_use]
342 pub fn is_randomized(&self) -> bool {
343 self.encryption_type == EncryptionTypeWire::Randomized
344 }
345}
346
347#[derive(Debug, Clone, Default)]
352pub struct ColumnCryptoInfo {
353 pub crypto_metadata: Option<CryptoMetadata>,
355}
356
357impl ColumnCryptoInfo {
358 #[must_use]
360 pub fn unencrypted() -> Self {
361 Self {
362 crypto_metadata: None,
363 }
364 }
365
366 #[must_use]
368 pub fn encrypted(metadata: CryptoMetadata) -> Self {
369 Self {
370 crypto_metadata: Some(metadata),
371 }
372 }
373
374 #[must_use]
376 pub fn is_encrypted(&self) -> bool {
377 self.crypto_metadata.is_some()
378 }
379}
380
381#[must_use]
383pub fn is_column_encrypted(flags: u16) -> bool {
384 (flags & COLUMN_FLAG_ENCRYPTED) != 0
385}
386
#[cfg(test)]
// Panicking on an unexpected `Err`/`None` is the intended failure mode in tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    #[test]
    fn test_encryption_type_wire_roundtrip() {
        assert_eq!(
            EncryptionTypeWire::from_u8(1),
            Some(EncryptionTypeWire::Deterministic)
        );
        assert_eq!(
            EncryptionTypeWire::from_u8(2),
            Some(EncryptionTypeWire::Randomized)
        );
        assert_eq!(EncryptionTypeWire::from_u8(0), None);
        assert_eq!(EncryptionTypeWire::from_u8(99), None);

        assert_eq!(EncryptionTypeWire::Deterministic.to_u8(), 1);
        assert_eq!(EncryptionTypeWire::Randomized.to_u8(), 2);
    }

    #[test]
    fn test_crypto_metadata_decode() {
        let data = [
            0x00, 0x00, // cek_table_ordinal = 0
            0x02, // algorithm_id = AEAD_AES_256_CBC_HMAC_SHA256
            0x01, // encryption_type = deterministic
            0x01, // normalization_version = 1
        ];

        let mut cursor: &[u8] = &data;
        let metadata = CryptoMetadata::decode(&mut cursor).unwrap();

        assert_eq!(metadata.cek_table_ordinal, 0);
        assert_eq!(
            metadata.algorithm_id,
            ALGORITHM_AEAD_AES_256_CBC_HMAC_SHA256
        );
        assert_eq!(metadata.encryption_type, EncryptionTypeWire::Deterministic);
        assert_eq!(metadata.normalization_version, 1);
        assert!(metadata.is_aead_aes_256());
        assert!(metadata.is_deterministic());
        assert!(!metadata.is_randomized());
    }

    #[test]
    fn test_crypto_metadata_decode_truncated() {
        // One byte short of CryptoMetadata::SIZE.
        let data = [0x00, 0x00, 0x02, 0x01];

        let mut cursor: &[u8] = &data;
        let err = CryptoMetadata::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::UnexpectedEof), "{:?}", err);
    }

    #[test]
    fn test_crypto_metadata_rejects_unknown_encryption_type() {
        // Encryption type byte 3 is neither deterministic (1) nor randomized (2).
        let data = [0x00, 0x00, 0x02, 0x03, 0x01];

        let mut cursor: &[u8] = &data;
        let err = CryptoMetadata::decode(&mut cursor).unwrap_err();
        assert!(
            matches!(
                err,
                ProtocolError::InvalidField {
                    field: "encryption_type",
                    value: 3
                }
            ),
            "{:?}",
            err
        );
    }

    #[test]
    fn test_cek_value_decode() {
        let mut data = BytesMut::new();

        // Encrypted value: u16 length 4, then the bytes.
        data.extend_from_slice(&[0x04, 0x00]);
        data.extend_from_slice(&[0xDE, 0xAD, 0xBE, 0xEF]);
        // Provider name: 1-byte char count 4, UTF-16LE "TEST".
        data.extend_from_slice(&[0x04]);
        data.extend_from_slice(&[b'T', 0x00, b'E', 0x00, b'S', 0x00, b'T', 0x00]);
        // CMK path: u16 char count 4, UTF-16LE "key1".
        data.extend_from_slice(&[0x04, 0x00]);
        data.extend_from_slice(&[b'k', 0x00, b'e', 0x00, b'y', 0x00, b'1', 0x00]);
        // Algorithm: 1-byte char count 3, UTF-16LE "RSA".
        data.extend_from_slice(&[0x03]);
        data.extend_from_slice(&[b'R', 0x00, b'S', 0x00, b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let value = CekValue::decode(&mut cursor).unwrap();

        assert_eq!(value.encrypted_value.as_ref(), &[0xDE, 0xAD, 0xBE, 0xEF]);
        assert_eq!(value.key_store_provider_name, "TEST");
        assert_eq!(value.cmk_path, "key1");
        assert_eq!(value.encryption_algorithm, "RSA");
    }

    #[test]
    fn test_cek_value_decode_truncated_encrypted_value() {
        // Declares a 4-byte encrypted value but supplies only 2 bytes.
        let data = [0x04, 0x00, 0xDE, 0xAD];

        let mut cursor: &[u8] = &data;
        let err = CekValue::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::UnexpectedEof), "{:?}", err);
    }

    #[test]
    fn test_cek_table_entry_decode() {
        let mut data = BytesMut::new();

        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]); // database_id = 1
        data.extend_from_slice(&[0x02, 0x00, 0x00, 0x00]); // cek_id = 2
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]); // cek_version = 1
        data.extend_from_slice(&[0x64, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); // cek_md_version = 100
        data.extend_from_slice(&[0x01]); // value count = 1

        data.extend_from_slice(&[0x04, 0x00]); // encrypted value length = 4
        data.extend_from_slice(&[0x11, 0x22, 0x33, 0x44]);
        data.extend_from_slice(&[0x02]); // provider name: 2 chars
        data.extend_from_slice(&[b'K', 0x00, b'S', 0x00]);
        data.extend_from_slice(&[0x01, 0x00]); // CMK path: 1 char
        data.extend_from_slice(&[b'P', 0x00]);
        data.extend_from_slice(&[0x01]); // algorithm: 1 char
        data.extend_from_slice(&[b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let entry = CekTableEntry::decode(&mut cursor).expect("should decode entry");

        assert_eq!(entry.database_id, 1);
        assert_eq!(entry.cek_id, 2);
        assert_eq!(entry.cek_version, 1);
        assert_eq!(entry.cek_md_version, 100);
        assert_eq!(entry.values.len(), 1);

        let value = entry.primary_value().expect("should have primary value");
        assert_eq!(value.encrypted_value.as_ref(), &[0x11, 0x22, 0x33, 0x44]);
        assert_eq!(value.key_store_provider_name, "KS");
        assert_eq!(value.cmk_path, "P");
        assert_eq!(value.encryption_algorithm, "A");
        assert!(cursor.is_empty(), "entry decode should consume all input");
    }

    #[test]
    fn test_cek_table_entry_decode_truncated_header() {
        // One byte short of the 21-byte fixed header.
        let data = [0u8; 20];

        let mut cursor: &[u8] = &data;
        let err = CekTableEntry::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::UnexpectedEof), "{:?}", err);
    }

    #[test]
    fn test_cek_table_decode() {
        let mut data = BytesMut::new();

        data.extend_from_slice(&[0x01, 0x00]); // entry count = 1

        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]); // database_id = 1
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]); // cek_id = 1
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00]); // cek_version = 1
        data.extend_from_slice(&[0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); // cek_md_version = 1
        data.extend_from_slice(&[0x01]); // value count = 1
        data.extend_from_slice(&[0x02, 0x00]); // encrypted value length = 2
        data.extend_from_slice(&[0xAB, 0xCD]);
        data.extend_from_slice(&[0x01]); // provider name: 1 char
        data.extend_from_slice(&[b'K', 0x00]);
        data.extend_from_slice(&[0x01, 0x00]); // CMK path: 1 char
        data.extend_from_slice(&[b'P', 0x00]);
        data.extend_from_slice(&[0x01]); // algorithm: 1 char
        data.extend_from_slice(&[b'A', 0x00]);

        let mut cursor: &[u8] = &data;
        let table = CekTable::decode(&mut cursor).expect("should decode table");

        assert_eq!(table.len(), 1);
        assert!(!table.is_empty());

        let entry = table.get(0).expect("should have first entry");
        assert_eq!(entry.database_id, 1);
        assert!(table.get(1).is_none());
        assert!(cursor.is_empty(), "table decode should consume all input");
    }

    #[test]
    fn test_cek_table_decode_empty_table() {
        let data = [0x00, 0x00]; // entry count = 0

        let mut cursor: &[u8] = &data;
        let table = CekTable::decode(&mut cursor).expect("should decode empty table");

        assert!(table.is_empty());
        assert_eq!(table.len(), 0);
        assert!(table.get(0).is_none());
    }

    #[test]
    fn test_cek_table_decode_missing_count() {
        let data: [u8; 0] = [];

        let mut cursor: &[u8] = &data;
        let err = CekTable::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::UnexpectedEof), "{:?}", err);
    }

    #[test]
    fn test_cek_table_decode_count_exceeds_data() {
        // Claims 65535 entries but carries none.
        let data = [0xFF, 0xFF];

        let mut cursor: &[u8] = &data;
        let err = CekTable::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, ProtocolError::UnexpectedEof), "{:?}", err);
    }

    #[test]
    fn test_is_column_encrypted() {
        assert!(!is_column_encrypted(0x0000));
        assert!(!is_column_encrypted(0x0001));
        assert!(is_column_encrypted(0x0800));
        assert!(is_column_encrypted(0x0801));
    }

    #[test]
    fn test_column_crypto_info() {
        let unencrypted = ColumnCryptoInfo::unencrypted();
        assert!(!unencrypted.is_encrypted());

        let metadata = CryptoMetadata {
            cek_table_ordinal: 0,
            algorithm_id: 2,
            encryption_type: EncryptionTypeWire::Randomized,
            normalization_version: 1,
        };
        let encrypted = ColumnCryptoInfo::encrypted(metadata);
        assert!(encrypted.is_encrypted());
    }
}