1use alloc::collections::BTreeMap;
27use alloc::vec::Vec;
28
29use ring::hkdf;
30use zerodds_security::authentication::{IdentityHandle, SharedSecretHandle};
31use zerodds_security::crypto::{CryptoHandle, CryptographicPlugin, ReceiverMac};
32use zerodds_security::error::{SecurityError, SecurityErrorKind, SecurityResult};
33
34use crate::plugin::AesGcmCryptoPlugin;
35use crate::suite::Suite;
36
/// Plugin class identifier reported by [`PskCryptoPlugin`] through
/// `CryptographicPlugin::plugin_class_id`.
pub const CLASS_ID_PSK_CRYPTO: &str = "DDS:Crypto:PSK:AES-GCM-GMAC:1.2";

/// HKDF `info` label used when deriving the PSK master key.
///
/// Both peers must use exactly these bytes; any difference yields different keys.
pub const HKDF_INFO_PSK_MASTER_KEY: &[u8] = b"DDS-Security-1.2-PSK-MasterKey";
43
/// Crypto plugin that derives participant key material from registered
/// pre-shared keys (PSKs).
///
/// All cryptographic transforms are delegated to the wrapped
/// [`AesGcmCryptoPlugin`]. This type adds PSK storage and HKDF-based
/// key-token derivation on top of it.
///
/// NOTE(review): PSK bytes are held in plain `Vec<u8>`; no zeroization on
/// drop is visible in this file.
pub struct PskCryptoPlugin {
    /// Underlying AES-GCM/GMAC plugin that performs all crypto operations.
    inner: AesGcmCryptoPlugin,
    /// Cipher suite shared with `inner`; it selects the master-key length and
    /// the transform kind written into derived tokens.
    suite: Suite,
    /// Registered pre-shared keys, indexed by caller-chosen PSK id.
    psks: BTreeMap<u64, Vec<u8>>,
}
55
56impl PskCryptoPlugin {
57 #[must_use]
59 pub fn new() -> Self {
60 Self::with_suite(Suite::Aes128Gcm)
61 }
62
63 #[must_use]
65 pub fn with_suite(suite: Suite) -> Self {
66 Self {
67 inner: AesGcmCryptoPlugin::with_suite(suite),
68 suite,
69 psks: BTreeMap::new(),
70 }
71 }
72
73 #[must_use]
75 pub fn suite(&self) -> Suite {
76 self.suite
77 }
78
79 pub fn register_psk(&mut self, psk_id: u64, key: Vec<u8>) -> SecurityResult<()> {
88 if key.is_empty() {
89 return Err(SecurityError::new(
90 SecurityErrorKind::BadArgument,
91 "psk-crypto: pre-shared-key leer",
92 ));
93 }
94 self.psks.insert(psk_id, key);
95 Ok(())
96 }
97
98 pub fn register_psk_remote(
107 &mut self,
108 local: CryptoHandle,
109 remote_identity: IdentityHandle,
110 psk_id: u64,
111 session_id: [u8; 4],
112 ) -> SecurityResult<CryptoHandle> {
113 let psk = self
114 .psks
115 .get(&psk_id)
116 .ok_or_else(|| {
117 SecurityError::new(
118 SecurityErrorKind::BadArgument,
119 "psk-crypto: psk_id nicht registriert",
120 )
121 })?
122 .clone();
123 let master_key = derive_psk_master_key(self.suite, &psk, &session_id)?;
124 let master_salt = derive_psk_master_salt(&psk, &session_id)?;
125 let key_id = derive_psk_key_id(&psk, &session_id)?;
126
127 let mut token = Vec::with_capacity(1 + 4 + 4 + 32 + master_key.len());
131 token.push(self.suite.transform_kind_id());
132 token.extend_from_slice(&session_id);
133 token.extend_from_slice(&key_id);
134 token.extend_from_slice(&master_salt);
135 token.extend_from_slice(&master_key);
136
137 let slot = self.inner.register_matched_remote_participant(
141 local,
142 remote_identity,
143 SharedSecretHandle(0),
144 )?;
145 self.inner
146 .set_remote_participant_crypto_tokens(local, slot, &token)?;
147 Ok(slot)
148 }
149
150 pub fn register_psk_local(
158 &mut self,
159 psk_id: u64,
160 session_id: [u8; 4],
161 ) -> SecurityResult<CryptoHandle> {
162 let psk = self
163 .psks
164 .get(&psk_id)
165 .ok_or_else(|| {
166 SecurityError::new(
167 SecurityErrorKind::BadArgument,
168 "psk-crypto: psk_id nicht registriert",
169 )
170 })?
171 .clone();
172 let master_key = derive_psk_master_key(self.suite, &psk, &session_id)?;
173 let master_salt = derive_psk_master_salt(&psk, &session_id)?;
174 let key_id = derive_psk_key_id(&psk, &session_id)?;
175 let mut token = Vec::with_capacity(1 + 4 + 4 + 32 + master_key.len());
176 token.push(self.suite.transform_kind_id());
177 token.extend_from_slice(&session_id);
178 token.extend_from_slice(&key_id);
179 token.extend_from_slice(&master_salt);
180 token.extend_from_slice(&master_key);
181
182 let slot = self
183 .inner
184 .register_local_participant(IdentityHandle(0), &[])?;
185 self.inner
186 .set_remote_participant_crypto_tokens(slot, slot, &token)?;
187 Ok(slot)
188 }
189}
190
191impl Default for PskCryptoPlugin {
192 fn default() -> Self {
193 Self::new()
194 }
195}
196
197fn derive_psk_master_key(
201 suite: Suite,
202 psk: &[u8],
203 session_id: &[u8; 4],
204) -> SecurityResult<Vec<u8>> {
205 derive_psk_field(psk, session_id, HKDF_INFO_PSK_MASTER_KEY, suite.key_len())
206}
207
/// HKDF `info` label for the 32-byte PSK master salt.
const HKDF_INFO_PSK_MASTER_SALT: &[u8] = b"DDS-Security-1.2-PSK-MasterSalt";
/// HKDF `info` label for the 4-byte sender key id.
const HKDF_INFO_PSK_KEY_ID: &[u8] = b"DDS-Security-1.2-PSK-SenderKeyId";
213
214fn derive_psk_master_salt(psk: &[u8], session_id: &[u8; 4]) -> SecurityResult<[u8; 32]> {
215 let v = derive_psk_field(psk, session_id, HKDF_INFO_PSK_MASTER_SALT, 32)?;
216 let mut out = [0u8; 32];
217 out.copy_from_slice(&v);
218 Ok(out)
219}
220
221fn derive_psk_key_id(psk: &[u8], session_id: &[u8; 4]) -> SecurityResult<[u8; 4]> {
222 let v = derive_psk_field(psk, session_id, HKDF_INFO_PSK_KEY_ID, 4)?;
223 let mut out = [0u8; 4];
224 out.copy_from_slice(&v);
225 Ok(out)
226}
227
228fn derive_psk_field(
229 psk: &[u8],
230 session_id: &[u8; 4],
231 info: &[u8],
232 out_len: usize,
233) -> SecurityResult<Vec<u8>> {
234 if psk.is_empty() {
235 return Err(SecurityError::new(
236 SecurityErrorKind::BadArgument,
237 "psk-crypto: empty psk",
238 ));
239 }
240 let salt_obj = hkdf::Salt::new(hkdf::HKDF_SHA256, session_id);
241 let prk = salt_obj.extract(psk);
242 let info_arr = [info];
243 let okm = prk
244 .expand(
245 &info_arr,
246 HkdfLen {
247 len: out_len,
248 hmac: hkdf::HKDF_SHA256,
249 },
250 )
251 .map_err(|_| {
252 SecurityError::new(SecurityErrorKind::CryptoFailed, "psk-crypto: HKDF expand")
253 })?;
254 let mut out = alloc::vec![0u8; out_len];
255 okm.fill(&mut out).map_err(|_| {
256 SecurityError::new(SecurityErrorKind::CryptoFailed, "psk-crypto: HKDF fill")
257 })?;
258 Ok(out)
259}
260
261struct HkdfLen {
262 len: usize,
263 hmac: hkdf::Algorithm,
264}
265
266impl hkdf::KeyType for HkdfLen {
267 fn len(&self) -> usize {
268 self.len
269 }
270}
271
272impl From<HkdfLen> for hkdf::Algorithm {
273 fn from(v: HkdfLen) -> Self {
274 v.hmac
275 }
276}
277
278impl CryptographicPlugin for PskCryptoPlugin {
279 fn register_local_participant(
280 &mut self,
281 identity: IdentityHandle,
282 properties: &[(&str, &str)],
283 ) -> SecurityResult<CryptoHandle> {
284 self.inner.register_local_participant(identity, properties)
285 }
286
287 fn register_matched_remote_participant(
288 &mut self,
289 local: CryptoHandle,
290 remote_identity: IdentityHandle,
291 shared_secret: SharedSecretHandle,
292 ) -> SecurityResult<CryptoHandle> {
293 self.inner
294 .register_matched_remote_participant(local, remote_identity, shared_secret)
295 }
296
297 fn register_local_endpoint(
298 &mut self,
299 participant: CryptoHandle,
300 is_writer: bool,
301 properties: &[(&str, &str)],
302 ) -> SecurityResult<CryptoHandle> {
303 self.inner
304 .register_local_endpoint(participant, is_writer, properties)
305 }
306
307 fn create_local_participant_crypto_tokens(
308 &mut self,
309 local: CryptoHandle,
310 remote: CryptoHandle,
311 ) -> SecurityResult<Vec<u8>> {
312 self.inner
313 .create_local_participant_crypto_tokens(local, remote)
314 }
315
316 fn set_remote_participant_crypto_tokens(
317 &mut self,
318 local: CryptoHandle,
319 remote: CryptoHandle,
320 tokens: &[u8],
321 ) -> SecurityResult<()> {
322 self.inner
323 .set_remote_participant_crypto_tokens(local, remote, tokens)
324 }
325
326 fn encrypt_submessage(
327 &self,
328 local: CryptoHandle,
329 remote_list: &[CryptoHandle],
330 plaintext: &[u8],
331 aad_extension: &[u8],
332 ) -> SecurityResult<Vec<u8>> {
333 self.inner
334 .encrypt_submessage(local, remote_list, plaintext, aad_extension)
335 }
336
337 fn decrypt_submessage(
338 &self,
339 local: CryptoHandle,
340 remote: CryptoHandle,
341 ciphertext: &[u8],
342 aad_extension: &[u8],
343 ) -> SecurityResult<Vec<u8>> {
344 self.inner
345 .decrypt_submessage(local, remote, ciphertext, aad_extension)
346 }
347
348 fn encrypt_submessage_multi(
349 &self,
350 local: CryptoHandle,
351 receivers: &[(CryptoHandle, u32)],
352 plaintext: &[u8],
353 aad_extension: &[u8],
354 ) -> SecurityResult<(Vec<u8>, Vec<ReceiverMac>)> {
355 self.inner
356 .encrypt_submessage_multi(local, receivers, plaintext, aad_extension)
357 }
358
359 #[allow(clippy::too_many_arguments)]
360 fn decrypt_submessage_with_receiver_mac(
361 &self,
362 local: CryptoHandle,
363 remote: CryptoHandle,
364 own_key_id: u32,
365 own_mac_key_handle: CryptoHandle,
366 ciphertext: &[u8],
367 macs: &[ReceiverMac],
368 aad_extension: &[u8],
369 ) -> SecurityResult<Vec<u8>> {
370 self.inner.decrypt_submessage_with_receiver_mac(
371 local,
372 remote,
373 own_key_id,
374 own_mac_key_handle,
375 ciphertext,
376 macs,
377 aad_extension,
378 )
379 }
380
381 fn plugin_class_id(&self) -> &str {
382 CLASS_ID_PSK_CRYPTO
383 }
384}
385
#[cfg(test)]
// Tests may unwrap/panic: a panic is the intended failure signal here.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// The plugin reports the PSK class id string.
    #[test]
    fn class_id_matches_spec() {
        let p = PskCryptoPlugin::new();
        assert_eq!(p.plugin_class_id(), "DDS:Crypto:PSK:AES-GCM-GMAC:1.2");
    }

    /// AES-128 suite maps to transform kind 0x02.
    #[test]
    fn transform_kind_id_aes128_matches_x509_path() {
        let p = PskCryptoPlugin::with_suite(Suite::Aes128Gcm);
        assert_eq!(p.suite().transform_kind_id(), 0x02);
    }

    /// AES-256 suite maps to transform kind 0x04.
    #[test]
    fn transform_kind_id_aes256_matches_x509_path() {
        let p = PskCryptoPlugin::with_suite(Suite::Aes256Gcm);
        assert_eq!(p.suite().transform_kind_id(), 0x04);
    }

    /// Same PSK + session id yields the same 16-byte AES-128 master key.
    #[test]
    fn psk_master_key_derivation_is_deterministic() {
        let psk = alloc::vec![0xAB; 32];
        let session = [0u8, 0, 0, 1];
        let k1 = derive_psk_master_key(Suite::Aes128Gcm, &psk, &session).unwrap();
        let k2 = derive_psk_master_key(Suite::Aes128Gcm, &psk, &session).unwrap();
        assert_eq!(k1, k2);
        assert_eq!(k1.len(), 16);
    }

    /// The session id acts as the HKDF salt, so changing it changes the key.
    #[test]
    fn psk_master_key_changes_with_session_id() {
        let psk = alloc::vec![0xAB; 32];
        let k1 = derive_psk_master_key(Suite::Aes128Gcm, &psk, &[0, 0, 0, 1]).unwrap();
        let k2 = derive_psk_master_key(Suite::Aes128Gcm, &psk, &[0, 0, 0, 2]).unwrap();
        assert_ne!(k1, k2);
    }

    /// Key derivation refuses an empty PSK.
    #[test]
    fn psk_master_key_rejects_empty_psk() {
        let err = derive_psk_master_key(Suite::Aes128Gcm, &[], &[0u8; 4]).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    /// Registration refuses an empty PSK.
    #[test]
    fn register_psk_rejects_empty_key() {
        let mut p = PskCryptoPlugin::new();
        let err = p.register_psk(1, Vec::new()).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    /// Registering a remote with an unknown PSK id is rejected.
    #[test]
    fn register_psk_remote_unknown_id_rejected() {
        let mut p = PskCryptoPlugin::new();
        let local = p
            .register_local_participant(IdentityHandle(1), &[])
            .unwrap();
        let err = p
            .register_psk_remote(local, IdentityHandle(2), 99, [0u8; 4])
            .unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    /// Two independent plugins sharing one PSK and session id interoperate:
    /// Alice's ciphertext decrypts on Bob's side.
    #[test]
    fn psk_encrypt_decrypt_roundtrip_two_plugins_same_psk() {
        let psk = alloc::vec![0x77u8; 32];
        let mut alice = PskCryptoPlugin::new();
        let mut bob = PskCryptoPlugin::new();
        alice.register_psk(7, psk.clone()).unwrap();
        bob.register_psk(7, psk).unwrap();

        let session = [0u8, 0, 0, 42];
        let alice_local = alice.register_psk_local(7, session).unwrap();
        let bob_local = bob.register_psk_local(7, session).unwrap();
        // Each side registers the other as a PSK-keyed remote.
        let alice_to_bob = alice
            .register_psk_remote(alice_local, IdentityHandle(2), 7, session)
            .unwrap();
        let bob_to_alice = bob
            .register_psk_remote(bob_local, IdentityHandle(1), 7, session)
            .unwrap();

        let plain = b"top-secret-psk-payload";
        let wire = alice
            .encrypt_submessage(alice_to_bob, &[], plain, &[])
            .unwrap();
        // Bob decrypts, passing his remote slot for Alice as both handles.
        let back = bob
            .decrypt_submessage(bob_to_alice, bob_to_alice, &wire, &[])
            .unwrap();
        assert_eq!(back, plain);
    }
}