zerodds_security_runtime/
gate.rs1use alloc::string::String;
7use alloc::vec::Vec;
8
9use zerodds_security::authentication::{IdentityHandle, SharedSecretHandle};
10use zerodds_security::crypto::{CryptoHandle, CryptographicPlugin};
11use zerodds_security::error::SecurityError;
12use zerodds_security_permissions::{Governance, ProtectionKind};
13use zerodds_security_rtps::{
14 RTPS_HEADER_LEN, SEC_PREFIX, SRTPS_PREFIX, SecurityRtpsError, decode_secured_rtps_message,
15 decode_secured_submessage, encode_secured_rtps_message, encode_secured_submessage,
16};
17
18#[derive(Debug)]
20pub enum SecurityGateError {
21 CryptoSetup(SecurityError),
23 Wrapper(SecurityRtpsError),
25 Crypto(SecurityError),
27 PolicyViolation(String),
30}
31
32impl core::fmt::Display for SecurityGateError {
33 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
34 match self {
35 Self::CryptoSetup(e) => write!(f, "security-gate setup: {e}"),
36 Self::Wrapper(e) => write!(f, "security-gate wrapper: {e}"),
37 Self::Crypto(e) => write!(f, "security-gate crypto: {e}"),
38 Self::PolicyViolation(m) => write!(f, "security-gate policy: {m}"),
39 }
40 }
41}
42
43#[cfg(feature = "std")]
44impl std::error::Error for SecurityGateError {}
45
46impl From<SecurityRtpsError> for SecurityGateError {
47 fn from(e: SecurityRtpsError) -> Self {
48 Self::Wrapper(e)
49 }
50}
51
52pub struct SecurityGate<'c, P: CryptographicPlugin> {
55 domain_id: u32,
56 governance: Governance,
57 crypto: &'c mut P,
58 local: Option<CryptoHandle>,
61}
62
63impl<'c, P: CryptographicPlugin> SecurityGate<'c, P> {
64 pub fn new(domain_id: u32, governance: Governance, crypto: &'c mut P) -> Self {
66 Self {
67 domain_id,
68 governance,
69 crypto,
70 local: None,
71 }
72 }
73
74 fn ensure_local(&mut self) -> Result<CryptoHandle, SecurityGateError> {
77 if let Some(h) = self.local {
78 return Ok(h);
79 }
80 let h = self
81 .crypto
82 .register_local_participant(IdentityHandle(1), &[])
83 .map_err(SecurityGateError::CryptoSetup)?;
84 self.local = Some(h);
85 Ok(h)
86 }
87
88 #[must_use]
91 pub fn outbound_protection(&self, topic_name: &str) -> ProtectionKind {
92 self.governance
93 .find_topic_rule(self.domain_id, topic_name)
94 .map(|r| r.data_protection_kind)
95 .unwrap_or(ProtectionKind::None)
96 }
97
98 pub fn encode_outbound(
105 &mut self,
106 topic_name: &str,
107 plaintext: &[u8],
108 ) -> Result<Vec<u8>, SecurityGateError> {
109 let kind = self.outbound_protection(topic_name);
110 match kind {
111 ProtectionKind::None => Ok(plaintext.to_vec()),
112 _ => {
113 let local = self.ensure_local()?;
114 let wrapped = encode_secured_submessage(self.crypto, local, &[], plaintext)?;
115 Ok(wrapped)
116 }
117 }
118 }
119
120 pub fn decode_inbound(
132 &mut self,
133 topic_name: &str,
134 wire: &[u8],
135 ) -> Result<Vec<u8>, SecurityGateError> {
136 let kind = self.outbound_protection(topic_name);
137 let looks_secured = !wire.is_empty() && wire[0] == SEC_PREFIX;
138 match (kind, looks_secured) {
139 (ProtectionKind::None, false) => Ok(wire.to_vec()),
140 (_, true) => {
141 let local = self.ensure_local()?;
142 decode_secured_submessage(self.crypto, local, local, wire)
143 .map_err(SecurityGateError::from)
144 }
145 (_, false) => Err(SecurityGateError::PolicyViolation(alloc::format!(
146 "topic '{topic_name}' verlangt {kind:?}, bekam plain-submessage"
147 ))),
148 }
149 }
150
151 pub fn register_remote(
158 &mut self,
159 remote_identity: IdentityHandle,
160 shared_secret: SharedSecretHandle,
161 ) -> Result<CryptoHandle, SecurityGateError> {
162 let local = self.ensure_local()?;
163 self.crypto
164 .register_matched_remote_participant(local, remote_identity, shared_secret)
165 .map_err(SecurityGateError::CryptoSetup)
166 }
167
168 pub fn local_token(&mut self) -> Result<Vec<u8>, SecurityGateError> {
174 let local = self.ensure_local()?;
175 self.crypto
176 .create_local_participant_crypto_tokens(local, CryptoHandle(0))
177 .map_err(SecurityGateError::Crypto)
178 }
179
180 pub fn set_remote_token(
186 &mut self,
187 remote: CryptoHandle,
188 token: &[u8],
189 ) -> Result<(), SecurityGateError> {
190 let local = self.ensure_local()?;
191 self.crypto
192 .set_remote_participant_crypto_tokens(local, remote, token)
193 .map_err(SecurityGateError::Crypto)
194 }
195
196 #[must_use]
200 pub fn message_protection(&self) -> ProtectionKind {
201 self.governance
202 .find_domain_rule(self.domain_id)
203 .map(|r| r.rtps_protection_kind)
204 .unwrap_or(ProtectionKind::None)
205 }
206
207 pub fn encode_outbound_message(
213 &mut self,
214 message: &[u8],
215 ) -> Result<Vec<u8>, SecurityGateError> {
216 match self.message_protection() {
217 ProtectionKind::None => Ok(message.to_vec()),
218 _ => {
219 let local = self.ensure_local()?;
220 encode_secured_rtps_message(self.crypto, local, &[], message)
221 .map_err(SecurityGateError::from)
222 }
223 }
224 }
225
226 pub fn decode_inbound_message(
238 &mut self,
239 remote_slot: CryptoHandle,
240 wire: &[u8],
241 ) -> Result<Vec<u8>, SecurityGateError> {
242 let looks_secured = wire.len() > RTPS_HEADER_LEN && wire[RTPS_HEADER_LEN] == SRTPS_PREFIX;
243 let kind = self.message_protection();
244 match (kind, looks_secured) {
245 (ProtectionKind::None, false) => Ok(wire.to_vec()),
246 (_, true) => {
247 decode_secured_rtps_message(self.crypto, remote_slot, remote_slot, wire)
249 .map_err(SecurityGateError::from)
250 }
251 (_, false) => Err(SecurityGateError::PolicyViolation(alloc::format!(
252 "domain {} verlangt {kind:?}, bekam plain-rtps-message",
253 self.domain_id
254 ))),
255 }
256 }
257}
258
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use zerodds_security_crypto::AesGcmCryptoPlugin;
    use zerodds_security_permissions::parse_governance_xml;

    /// Domain 0: topics matching `Secret*` require ENCRYPT, all other topics
    /// are unprotected. No RTPS-level protection is configured.
    const GOV: &str = r#"
<domain_access_rules>
  <domain_rule>
    <domains><id>0</id></domains>
    <topic_access_rules>
      <topic_rule>
        <topic_expression>Secret*</topic_expression>
        <data_protection_kind>ENCRYPT</data_protection_kind>
      </topic_rule>
      <topic_rule>
        <topic_expression>*</topic_expression>
        <data_protection_kind>NONE</data_protection_kind>
      </topic_rule>
    </topic_access_rules>
  </domain_rule>
</domain_access_rules>
"#;

    #[test]
    fn outbound_protection_reads_governance_topic_rule() {
        let gov = parse_governance_xml(GOV).unwrap();
        let mut crypto = AesGcmCryptoPlugin::new();
        let gate = SecurityGate::new(0, gov, &mut crypto);
        assert_eq!(
            gate.outbound_protection("SecretRecipe"),
            ProtectionKind::Encrypt
        );
        assert_eq!(gate.outbound_protection("Chatter"), ProtectionKind::None);
    }

    #[test]
    fn encode_none_is_passthrough_byte_identical() {
        let gov = parse_governance_xml(GOV).unwrap();
        let mut crypto = AesGcmCryptoPlugin::new();
        let mut gate = SecurityGate::new(0, gov, &mut crypto);
        let plain = b"plaintext submessage";
        let wire = gate.encode_outbound("Chatter", plain).unwrap();
        assert_eq!(wire, plain);
    }

    #[test]
    fn encode_encrypt_wraps_in_sec_prefix() {
        let gov = parse_governance_xml(GOV).unwrap();
        let mut crypto = AesGcmCryptoPlugin::new();
        let mut gate = SecurityGate::new(0, gov, &mut crypto);
        let secret = b"top-secret";
        let wire = gate.encode_outbound("SecretOrder", secret).unwrap();
        assert_eq!(wire[0], SEC_PREFIX, "must begin with SEC_PREFIX");
        // The plaintext must not appear anywhere in the encrypted output.
        assert!(
            !wire.windows(secret.len()).any(|w| w == secret),
            "plaintext sollte nicht im wire sein"
        );
    }

    #[test]
    fn encode_decode_roundtrip_via_gate() {
        let gov = parse_governance_xml(GOV).unwrap();
        let mut crypto = AesGcmCryptoPlugin::new();
        let mut gate = SecurityGate::new(0, gov, &mut crypto);
        let wire = gate.encode_outbound("SecretOrder", b"hello").unwrap();
        let back = gate.decode_inbound("SecretOrder", &wire).unwrap();
        assert_eq!(back, b"hello");
    }

    #[test]
    fn inbound_plain_on_protected_topic_is_policy_violation() {
        let gov = parse_governance_xml(GOV).unwrap();
        let mut crypto = AesGcmCryptoPlugin::new();
        let mut gate = SecurityGate::new(0, gov, &mut crypto);
        // Plain bytes on an ENCRYPT topic must be rejected, not passed through.
        let err = gate
            .decode_inbound("SecretOrder", b"plaintext-leak")
            .unwrap_err();
        assert!(matches!(err, SecurityGateError::PolicyViolation(_)));
    }

    #[test]
    fn inbound_empty_on_protected_topic_is_policy_violation() {
        let gov = parse_governance_xml(GOV).unwrap();
        let mut crypto = AesGcmCryptoPlugin::new();
        let mut gate = SecurityGate::new(0, gov, &mut crypto);
        // Empty input has no SEC_PREFIX byte, so it counts as plain data.
        let err = gate.decode_inbound("SecretOrder", &[]).unwrap_err();
        assert!(matches!(err, SecurityGateError::PolicyViolation(_)));
    }

    #[test]
    fn inbound_plain_on_unprotected_topic_passes_through() {
        let gov = parse_governance_xml(GOV).unwrap();
        let mut crypto = AesGcmCryptoPlugin::new();
        let mut gate = SecurityGate::new(0, gov, &mut crypto);
        let back = gate.decode_inbound("Chatter", b"plain-ok").unwrap();
        assert_eq!(back, b"plain-ok");
    }

    #[test]
    fn missing_domain_rule_defaults_to_none() {
        let gov = parse_governance_xml(GOV).unwrap();
        // Domain 99 has no rule in GOV, so every topic falls back to NONE.
        let mut crypto = AesGcmCryptoPlugin::new();
        let gate = SecurityGate::new(99, gov, &mut crypto);
        assert_eq!(
            gate.outbound_protection("SecretOrder"),
            ProtectionKind::None
        );
    }

    /// Domain 0 with RTPS-level (whole-message) ENCRYPT protection.
    const GOV_RTPS: &str = r#"
<domain_access_rules>
  <domain_rule>
    <domains><id>0</id></domains>
    <rtps_protection_kind>ENCRYPT</rtps_protection_kind>
    <topic_access_rules>
      <topic_rule><topic_expression>*</topic_expression></topic_rule>
    </topic_access_rules>
  </domain_rule>
</domain_access_rules>
"#;

    /// Builds a minimal RTPS message: `RTPS`, protocol version bytes 02 05,
    /// vendor id bytes 01 02 and a zeroed 12-byte GUID prefix, followed by
    /// `body`.
    fn fake_rtps_message(body: &[u8]) -> Vec<u8> {
        let mut m = Vec::with_capacity(RTPS_HEADER_LEN + body.len());
        m.extend_from_slice(b"RTPS\x02\x05\x01\x02");
        m.extend_from_slice(&[0u8; 12]);
        assert_eq!(
            m.len(),
            RTPS_HEADER_LEN,
            "fake header must be exactly RTPS_HEADER_LEN bytes"
        );
        m.extend_from_slice(body);
        m
    }

    #[test]
    fn message_protection_reads_domain_rule() {
        let gov = parse_governance_xml(GOV_RTPS).unwrap();
        let mut crypto = AesGcmCryptoPlugin::new();
        let gate = SecurityGate::new(0, gov, &mut crypto);
        assert_eq!(gate.message_protection(), ProtectionKind::Encrypt);
    }

    #[test]
    fn message_encode_none_is_passthrough() {
        let gov = parse_governance_xml(GOV).unwrap();
        // GOV configures no rtps_protection_kind for domain 0.
        let mut crypto = AesGcmCryptoPlugin::new();
        let mut gate = SecurityGate::new(0, gov, &mut crypto);
        let msg = fake_rtps_message(b"plain");
        let wire = gate.encode_outbound_message(&msg).unwrap();
        assert_eq!(wire, msg);
    }

    #[test]
    fn message_decode_none_is_passthrough() {
        let gov = parse_governance_xml(GOV).unwrap();
        let mut crypto = AesGcmCryptoPlugin::new();
        let mut gate = SecurityGate::new(0, gov, &mut crypto);
        let msg = fake_rtps_message(b"plain");
        let back = gate
            .decode_inbound_message(CryptoHandle(1), &msg)
            .unwrap();
        assert_eq!(back, msg);
    }

    #[test]
    fn message_encode_encrypt_wraps_after_header() {
        let gov = parse_governance_xml(GOV_RTPS).unwrap();
        let mut crypto = AesGcmCryptoPlugin::new();
        let mut gate = SecurityGate::new(0, gov, &mut crypto);
        let msg = fake_rtps_message(b"[DATA][HEARTBEAT]");
        let wire = gate.encode_outbound_message(&msg).unwrap();
        assert_eq!(&wire[..4], b"RTPS");
        assert_eq!(wire[RTPS_HEADER_LEN], SRTPS_PREFIX);
    }

    #[test]
    fn message_policy_violation_on_plain_inbound() {
        let gov = parse_governance_xml(GOV_RTPS).unwrap();
        let mut crypto = AesGcmCryptoPlugin::new();
        let mut gate = SecurityGate::new(0, gov, &mut crypto);
        let plain = fake_rtps_message(b"nope");
        // An unwrapped message on an ENCRYPT domain must be rejected.
        let err = gate
            .decode_inbound_message(CryptoHandle(1), &plain)
            .unwrap_err();
        assert!(matches!(err, SecurityGateError::PolicyViolation(_)));
    }

    #[test]
    fn message_header_only_on_protected_domain_is_policy_violation() {
        let gov = parse_governance_xml(GOV_RTPS).unwrap();
        let mut crypto = AesGcmCryptoPlugin::new();
        let mut gate = SecurityGate::new(0, gov, &mut crypto);
        // Exactly RTPS_HEADER_LEN bytes: there is no byte at the SRTPS_PREFIX
        // offset, so the message counts as plain.
        let header_only = fake_rtps_message(&[]);
        let err = gate
            .decode_inbound_message(CryptoHandle(1), &header_only)
            .unwrap_err();
        assert!(matches!(err, SecurityGateError::PolicyViolation(_)));
    }

    #[test]
    fn e2e_cross_participant_message_roundtrip() {
        // Two independent participants (separate plugins) exchange tokens and
        // then send one protected RTPS message from alice to bob.
        let gov1 = parse_governance_xml(GOV_RTPS).unwrap();
        let gov2 = parse_governance_xml(GOV_RTPS).unwrap();
        let mut alice_crypto = AesGcmCryptoPlugin::new();
        let mut bob_crypto = AesGcmCryptoPlugin::new();

        let mut alice = SecurityGate::new(0, gov1, &mut alice_crypto);
        let mut bob = SecurityGate::new(0, gov2, &mut bob_crypto);

        let alice_token = alice.local_token().unwrap();
        let bob_token = bob.local_token().unwrap();

        let alice_view_of_bob = alice
            .register_remote(IdentityHandle(2), SharedSecretHandle(1))
            .unwrap();
        alice
            .set_remote_token(alice_view_of_bob, &bob_token)
            .unwrap();

        let bob_view_of_alice = bob
            .register_remote(IdentityHandle(1), SharedSecretHandle(1))
            .unwrap();
        bob.set_remote_token(bob_view_of_alice, &alice_token)
            .unwrap();

        let msg = fake_rtps_message(b"[DATA:cross-participant]");
        let wire = alice.encode_outbound_message(&msg).unwrap();

        // Bob decodes with the handle he registered for alice.
        let back = bob
            .decode_inbound_message(bob_view_of_alice, &wire)
            .unwrap();
        assert_eq!(back, msg);
    }
}