1use alloc::string::{String, ToString};
32use alloc::vec::Vec;
33
34use zerodds_rtps::property_list::{WireProperty, WirePropertyList};
35use zerodds_security_pki::DelegationChain;
36
37use crate::caps::PeerCapabilities;
38use crate::policy::{ProtectionLevel, SuiteHint};
39
/// Property key carrying the authentication plugin class identifier
/// (mirrors `PeerCapabilities::auth_plugin_class`).
pub const KEY_AUTH_PLUGIN: &str = "dds.sec.auth.plugin_class";
/// Property key carrying the access-control plugin class identifier
/// (mirrors `PeerCapabilities::access_plugin_class`).
pub const KEY_ACCESS_PLUGIN: &str = "dds.sec.access.plugin_class";
/// Property key carrying the cryptographic plugin class identifier
/// (mirrors `PeerCapabilities::crypto_plugin_class`).
pub const KEY_CRYPTO_PLUGIN: &str = "dds.sec.crypto.plugin_class";
/// Property key carrying the comma-separated list of supported suites,
/// e.g. `AES_128_GCM,AES_256_GCM`.
pub const KEY_SUPPORTED_SUITES: &str = "zerodds.sec.supported_suites";
/// Property key carrying the offered protection level: `NONE`, `SIGN` or `ENCRYPT`.
pub const KEY_OFFERED_PROTECTION: &str = "zerodds.sec.offered_protection";
/// Property key carrying the vendor hint string, copied verbatim.
pub const KEY_VENDOR_HINT: &str = "zerodds.sec.vendor_hint";
/// Property key carrying the base64-encoded serialized delegation chain.
pub const KEY_DELEGATION_CHAIN: &str = "zerodds.sec.delegation_chain";

/// Upper bound (8 KiB) on the serialized (pre-base64) size of a delegation chain.
///
/// Chains larger than this are not advertised, and received values that decode to more
/// bytes than this are ignored.
pub const MAX_DELEGATION_CHAIN_BYTES: usize = 8_192;
63
64fn suite_to_str(s: SuiteHint) -> &'static str {
69 match s {
70 SuiteHint::Aes128Gcm => "AES_128_GCM",
71 SuiteHint::Aes256Gcm => "AES_256_GCM",
72 SuiteHint::HmacSha256 => "HMAC_SHA256",
73 }
74}
75
76fn suite_from_str(s: &str) -> Option<SuiteHint> {
77 match s.trim() {
78 "AES_128_GCM" => Some(SuiteHint::Aes128Gcm),
79 "AES_256_GCM" => Some(SuiteHint::Aes256Gcm),
80 "HMAC_SHA256" => Some(SuiteHint::HmacSha256),
81 _ => None,
82 }
83}
84
85fn suites_to_csv(suites: &[SuiteHint]) -> String {
86 let mut out = String::new();
87 for (i, s) in suites.iter().enumerate() {
88 if i > 0 {
89 out.push(',');
90 }
91 out.push_str(suite_to_str(*s));
92 }
93 out
94}
95
96fn suites_from_csv(csv: &str) -> Vec<SuiteHint> {
97 csv.split(',').filter_map(suite_from_str).collect()
98}
99
100fn protection_to_str(p: ProtectionLevel) -> &'static str {
105 match p {
106 ProtectionLevel::None => "NONE",
107 ProtectionLevel::Sign => "SIGN",
108 ProtectionLevel::Encrypt => "ENCRYPT",
109 }
110}
111
112fn protection_from_str(s: &str) -> Option<ProtectionLevel> {
113 match s.trim() {
114 "NONE" => Some(ProtectionLevel::None),
115 "SIGN" => Some(ProtectionLevel::Sign),
116 "ENCRYPT" => Some(ProtectionLevel::Encrypt),
117 _ => None,
118 }
119}
120
121pub fn advertise_security_caps(list: &mut WirePropertyList, caps: &PeerCapabilities) {
132 set_or_remove(list, KEY_AUTH_PLUGIN, caps.auth_plugin_class.as_deref());
133 set_or_remove(list, KEY_ACCESS_PLUGIN, caps.access_plugin_class.as_deref());
134 set_or_remove(list, KEY_CRYPTO_PLUGIN, caps.crypto_plugin_class.as_deref());
135 if !caps.supported_suites.is_empty() {
136 set_value(
137 list,
138 KEY_SUPPORTED_SUITES,
139 &suites_to_csv(&caps.supported_suites),
140 );
141 } else {
142 remove_by_key(list, KEY_SUPPORTED_SUITES);
143 }
144 set_value(
145 list,
146 KEY_OFFERED_PROTECTION,
147 protection_to_str(caps.offered_protection),
148 );
149 set_or_remove(list, KEY_VENDOR_HINT, caps.vendor_hint.as_deref());
150 if let Some(chain) = &caps.delegation_chain {
152 let raw = chain.encode();
153 if raw.len() <= MAX_DELEGATION_CHAIN_BYTES {
154 let b64 = base64_encode(&raw);
155 set_value(list, KEY_DELEGATION_CHAIN, &b64);
156 } else {
157 remove_by_key(list, KEY_DELEGATION_CHAIN);
160 }
161 } else {
162 remove_by_key(list, KEY_DELEGATION_CHAIN);
163 }
164}
165
166#[must_use]
171pub fn parse_peer_caps(list: &WirePropertyList) -> PeerCapabilities {
172 let offered_protection = list
173 .get(KEY_OFFERED_PROTECTION)
174 .and_then(protection_from_str)
175 .unwrap_or(ProtectionLevel::None);
176 let supported_suites = list
177 .get(KEY_SUPPORTED_SUITES)
178 .map(suites_from_csv)
179 .unwrap_or_default();
180 let delegation_chain = list
181 .get(KEY_DELEGATION_CHAIN)
182 .and_then(|s| {
183 if s.len() > MAX_DELEGATION_CHAIN_BYTES * 4 / 3 + 4 {
185 return None;
186 }
187 base64_decode(s).ok()
188 })
189 .filter(|raw| raw.len() <= MAX_DELEGATION_CHAIN_BYTES)
190 .and_then(|raw| DelegationChain::decode(&raw).ok());
191 PeerCapabilities {
192 auth_plugin_class: list.get(KEY_AUTH_PLUGIN).map(str::to_string),
193 access_plugin_class: list.get(KEY_ACCESS_PLUGIN).map(str::to_string),
194 crypto_plugin_class: list.get(KEY_CRYPTO_PLUGIN).map(str::to_string),
195 supported_suites,
196 offered_protection,
197 has_valid_cert: false,
198 validity_window: None,
199 vendor_hint: list.get(KEY_VENDOR_HINT).map(str::to_string),
200 cert_cn: None,
204 delegation_chain,
205 }
206}
207
/// Standard base64 alphabet (RFC 4648 §4), indexed by 6-bit value.
const B64_ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Encodes `input` as standard, `=`-padded base64.
///
/// Every 3-byte group becomes 4 characters. A trailing 1- or 2-byte group is zero-filled
/// and padded with `==` or `=` respectively.
fn base64_encode(input: &[u8]) -> String {
    // Maps the 6-bit field of `word` starting at bit `shift` to its alphabet character.
    let sextet = |word: u32, shift: u32| char::from(B64_ALPHABET[((word >> shift) & 0x3F) as usize]);

    let mut encoded = String::with_capacity(input.len().div_ceil(3) * 4);
    for group in input.chunks(3) {
        let hi = u32::from(group[0]);
        let mid = group.get(1).copied().map_or(0, u32::from);
        let lo = group.get(2).copied().map_or(0, u32::from);
        let word = (hi << 16) | (mid << 8) | lo;

        encoded.push(sextet(word, 18));
        encoded.push(sextet(word, 12));
        encoded.push(if group.len() > 1 { sextet(word, 6) } else { '=' });
        encoded.push(if group.len() > 2 { sextet(word, 0) } else { '=' });
    }
    encoded
}
248
/// Maps one character of the standard base64 alphabet to its 6-bit value.
///
/// Returns `None` for anything else, including the padding character `=`.
fn base64_char_to_val(c: u8) -> Option<u8> {
    let value = match c {
        b'A'..=b'Z' => c - b'A',
        b'a'..=b'z' => c - b'a' + 26,
        b'0'..=b'9' => c - b'0' + 52,
        b'+' => 62,
        b'/' => 63,
        _ => return None,
    };
    Some(value)
}
259
260fn base64_decode(input: &str) -> Result<Vec<u8>, ()> {
261 let bytes = input.as_bytes();
262 if bytes.len() % 4 != 0 {
263 return Err(());
264 }
265 let mut out = Vec::with_capacity(bytes.len() / 4 * 3);
266 for chunk in bytes.chunks_exact(4) {
267 let mut vals = [0u8; 4];
268 let mut pad = 0usize;
269 for (i, &c) in chunk.iter().enumerate() {
270 if c == b'=' {
271 pad += 1;
272 vals[i] = 0;
273 } else if pad > 0 {
274 return Err(());
275 } else {
276 vals[i] = base64_char_to_val(c).ok_or(())?;
277 }
278 }
279 let n = (u32::from(vals[0]) << 18)
280 | (u32::from(vals[1]) << 12)
281 | (u32::from(vals[2]) << 6)
282 | u32::from(vals[3]);
283 out.push(((n >> 16) & 0xFF) as u8);
284 if pad < 2 {
285 out.push(((n >> 8) & 0xFF) as u8);
286 }
287 if pad < 1 {
288 out.push((n & 0xFF) as u8);
289 }
290 }
291 Ok(out)
292}
293
294fn set_value(list: &mut WirePropertyList, key: &str, value: &str) {
300 remove_by_key(list, key);
301 list.push(WireProperty::new(key.to_string(), value.to_string()));
302}
303
304fn set_or_remove(list: &mut WirePropertyList, key: &str, value: Option<&str>) {
306 match value {
307 Some(v) => set_value(list, key, v),
308 None => remove_by_key(list, key),
309 }
310}
311
312fn remove_by_key(list: &mut WirePropertyList, key: &str) {
313 list.entries.retain(|e| e.name != key);
314}
315
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::caps::Validity;

    /// A fully populated capability set, including fields that never reach the wire.
    fn secure_caps() -> PeerCapabilities {
        PeerCapabilities {
            auth_plugin_class: Some("DDS:Auth:PKI-DH:1.2".to_string()),
            access_plugin_class: Some("DDS:Access:Permissions:1.2".to_string()),
            crypto_plugin_class: Some("DDS:Crypto:AES-GCM-GMAC:1.2".to_string()),
            supported_suites: alloc::vec![SuiteHint::Aes128Gcm, SuiteHint::Aes256Gcm],
            offered_protection: ProtectionLevel::Encrypt,
            has_valid_cert: true,
            validity_window: Some(Validity {
                not_before: 0,
                not_after: 100,
            }),
            vendor_hint: Some("zerodds".to_string()),
            cert_cn: None,
            delegation_chain: None,
        }
    }

    /// Advertises `caps` into an empty property list and parses the result back.
    fn wire_roundtrip(caps: &PeerCapabilities) -> PeerCapabilities {
        let mut list = WirePropertyList::new();
        advertise_security_caps(&mut list, caps);
        parse_peer_caps(&list)
    }

    #[test]
    fn suite_csv_roundtrip() {
        let every_suite = [
            SuiteHint::Aes128Gcm,
            SuiteHint::Aes256Gcm,
            SuiteHint::HmacSha256,
        ];
        let csv = suites_to_csv(&every_suite);
        assert_eq!(csv, "AES_128_GCM,AES_256_GCM,HMAC_SHA256");
        assert_eq!(suites_from_csv(&csv), every_suite.to_vec());
    }

    #[test]
    fn suite_csv_empty() {
        assert_eq!(suites_to_csv(&[]), "");
        assert!(suites_from_csv("").is_empty());
    }

    #[test]
    fn suite_csv_ignores_unknown_tokens() {
        assert_eq!(
            suites_from_csv("AES_128_GCM,FUTURE_SUITE,HMAC_SHA256"),
            [SuiteHint::Aes128Gcm, SuiteHint::HmacSha256]
        );
    }

    #[test]
    fn suite_csv_trims_whitespace() {
        assert_eq!(
            suites_from_csv(" AES_128_GCM , AES_256_GCM "),
            [SuiteHint::Aes128Gcm, SuiteHint::Aes256Gcm]
        );
    }

    #[test]
    fn protection_string_roundtrip_all_levels() {
        let levels = [
            ProtectionLevel::None,
            ProtectionLevel::Sign,
            ProtectionLevel::Encrypt,
        ];
        for level in levels {
            let token = protection_to_str(level);
            assert_eq!(protection_from_str(token), Some(level));
        }
    }

    #[test]
    fn protection_from_str_unknown_is_none() {
        assert_eq!(protection_from_str("WEIRD"), None);
    }

    #[test]
    fn roundtrip_preserves_wire_fields() {
        let original = secure_caps();
        let parsed = wire_roundtrip(&original);

        assert_eq!(parsed.auth_plugin_class, original.auth_plugin_class);
        assert_eq!(parsed.access_plugin_class, original.access_plugin_class);
        assert_eq!(parsed.crypto_plugin_class, original.crypto_plugin_class);
        assert_eq!(parsed.supported_suites, original.supported_suites);
        assert_eq!(parsed.offered_protection, original.offered_protection);
        assert_eq!(parsed.vendor_hint, original.vendor_hint);
    }

    #[test]
    fn roundtrip_drops_non_wire_fields() {
        let parsed = wire_roundtrip(&secure_caps());

        assert!(!parsed.has_valid_cert);
        assert!(parsed.validity_window.is_none());
    }

    #[test]
    fn legacy_peer_without_security_properties_parses_as_empty() {
        let parsed = parse_peer_caps(&WirePropertyList::new());

        assert!(parsed.auth_plugin_class.is_none());
        assert!(parsed.crypto_plugin_class.is_none());
        assert!(parsed.access_plugin_class.is_none());
        assert!(parsed.supported_suites.is_empty());
        assert_eq!(parsed.offered_protection, ProtectionLevel::None);
        assert!(parsed.vendor_hint.is_none());
    }

    #[test]
    fn advertise_overwrites_existing_keys() {
        let mut list = WirePropertyList::new();
        list.push(WireProperty::new(KEY_OFFERED_PROTECTION, "SIGN"));
        list.push(WireProperty::new(KEY_AUTH_PLUGIN, "stale-value"));

        let fresh = PeerCapabilities {
            auth_plugin_class: Some("DDS:Auth:PKI-DH:1.2".to_string()),
            offered_protection: ProtectionLevel::Encrypt,
            ..Default::default()
        };
        advertise_security_caps(&mut list, &fresh);

        assert_eq!(list.get(KEY_OFFERED_PROTECTION), Some("ENCRYPT"));
        assert_eq!(list.get(KEY_AUTH_PLUGIN), Some("DDS:Auth:PKI-DH:1.2"));
    }

    #[test]
    fn advertise_keeps_foreign_properties_intact() {
        let mut list = WirePropertyList::new();
        list.push(WireProperty::new("foreign.key", "keep-me"));

        advertise_security_caps(&mut list, &secure_caps());

        assert_eq!(list.get("foreign.key"), Some("keep-me"));
    }

    #[test]
    fn advertise_removes_keys_when_caps_field_is_none() {
        let mut list = WirePropertyList::new();
        list.push(WireProperty::new(KEY_AUTH_PLUGIN, "DDS:Auth:PKI-DH:1.2"));

        let without_auth = PeerCapabilities {
            auth_plugin_class: None,
            ..Default::default()
        };
        advertise_security_caps(&mut list, &without_auth);

        assert!(list.get(KEY_AUTH_PLUGIN).is_none());
    }

    #[test]
    fn advertise_is_idempotent() {
        let caps = secure_caps();

        let mut once = WirePropertyList::new();
        advertise_security_caps(&mut once, &caps);

        let mut twice = WirePropertyList::new();
        for _ in 0..2 {
            advertise_security_caps(&mut twice, &caps);
        }

        assert_eq!(once, twice);
    }

    #[test]
    fn parse_malformed_protection_falls_back_to_none() {
        let list =
            WirePropertyList::new().with(WireProperty::new(KEY_OFFERED_PROTECTION, "MAXIMAL"));
        assert_eq!(
            parse_peer_caps(&list).offered_protection,
            ProtectionLevel::None
        );
    }

    #[test]
    fn parse_malformed_suite_csv_drops_invalid_tokens() {
        let list = WirePropertyList::new()
            .with(WireProperty::new(KEY_SUPPORTED_SUITES, "AES_128_GCM,BOGUS"));
        assert_eq!(
            parse_peer_caps(&list).supported_suites,
            [SuiteHint::Aes128Gcm]
        );
    }

    #[test]
    fn advertise_with_no_suites_omits_suites_key() {
        let sign_only = PeerCapabilities {
            offered_protection: ProtectionLevel::Sign,
            ..Default::default()
        };
        let mut list = WirePropertyList::new();

        advertise_security_caps(&mut list, &sign_only);

        assert!(list.get(KEY_SUPPORTED_SUITES).is_none());
    }

    #[test]
    fn unknown_foreign_properties_dont_affect_parse() {
        let list = WirePropertyList::new()
            .with(WireProperty::new("com.rti.dds.Priority", "9"))
            .with(WireProperty::new("org.eprosima.fastdds.type", "X"))
            .with(WireProperty::new(KEY_OFFERED_PROTECTION, "SIGN"));
        assert_eq!(
            parse_peer_caps(&list).offered_protection,
            ProtectionLevel::Sign
        );
    }
}
535
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod base64_and_delegation_tests {
    use super::*;

    /// RFC 4648 §10 test vectors, encoding direction.
    #[test]
    fn base64_encode_known_vectors() {
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foob"), "Zm9vYg==");
        assert_eq!(base64_encode(b"fooba"), "Zm9vYmE=");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
    }

    /// RFC 4648 §10 test vectors, decoding direction.
    #[test]
    fn base64_decode_known_vectors() {
        assert_eq!(base64_decode("").unwrap(), b"");
        assert_eq!(base64_decode("Zg==").unwrap(), b"f");
        assert_eq!(base64_decode("Zm8=").unwrap(), b"fo");
        assert_eq!(base64_decode("Zm9v").unwrap(), b"foo");
        assert_eq!(base64_decode("Zm9vYg==").unwrap(), b"foob");
        assert_eq!(base64_decode("Zm9vYmE=").unwrap(), b"fooba");
        assert_eq!(base64_decode("Zm9vYmFy").unwrap(), b"foobar");
    }

    #[test]
    fn base64_decode_rejects_bad_length() {
        assert!(base64_decode("ABC").is_err());
        assert!(base64_decode("A").is_err());
    }

    #[test]
    fn base64_decode_rejects_bad_chars() {
        assert!(base64_decode("AB!?").is_err());
        assert!(base64_decode("@@@@").is_err());
    }

    /// Every byte value (including 0xFF) must survive an encode/decode round trip.
    /// Each input-length remainder class (len % 3 == 0, 1, 2) is exercised on its own.
    #[test]
    fn base64_roundtrip_random_bytes() {
        let blob: alloc::vec::Vec<u8> = (0..=u8::MAX).collect();
        for len in (0..=6).chain([blob.len()]) {
            let slice = &blob[..len];
            let encoded = base64_encode(slice);
            assert_eq!(encoded.len(), len.div_ceil(3) * 4);
            assert_eq!(base64_decode(&encoded).unwrap(), slice);
        }
    }

    /// A delegation-chain property whose base64 text alone exceeds the size bound must be
    /// ignored rather than decoded.
    #[test]
    fn parse_skips_oversize_base64_property() {
        let mut list = WirePropertyList::new();
        let huge = "A".repeat(MAX_DELEGATION_CHAIN_BYTES * 4 / 3 + 100);
        list.push(WireProperty::new(KEY_DELEGATION_CHAIN, huge.as_str()));
        let parsed = parse_peer_caps(&list);
        assert!(parsed.delegation_chain.is_none());
    }
}