1extern crate alloc;
46
47use alloc::string::{String, ToString};
48use alloc::vec::Vec;
49
50use crate::error::{SecurityError, SecurityErrorKind, SecurityResult};
51use crate::token::DataHolder;
52
/// Topic name string for participant stateless messages.
///
/// The unit tests pin this exact spelling so it cannot drift from the
/// specification text.
pub const TOPIC_STATELESS_MESSAGE: &str = "DCPSParticipantStatelessMessage";

/// Topic name string for secure volatile participant messages.
///
/// The unit tests pin this exact spelling as well.
pub const TOPIC_VOLATILE_MESSAGE_SECURE: &str = "DCPSParticipantVolatileMessageSecure";

/// Type name string that goes with [`ParticipantGenericMessage`].
pub const TYPE_NAME_GENERIC_MESSAGE: &str = "ParticipantGenericMessage";
61
/// Known string values for [`ParticipantGenericMessage::message_class_id`].
///
/// NOTE(review): the per-constant descriptions below are inferred from the
/// constant names and literals only; confirm the exact semantics against the
/// DDS Security specification.
pub mod class_id {
    /// Class id used for authentication-request messages.
    pub const AUTH_REQUEST: &str = "dds.sec.auth_request";
    /// Class id used for authentication (handshake) messages.
    pub const AUTH: &str = "dds.sec.auth";
    /// Class id used for messages carrying participant crypto tokens.
    pub const PARTICIPANT_CRYPTO_TOKENS: &str = "dds.sec.participant_crypto_tokens";
    /// Class id used for messages carrying data-writer crypto tokens.
    pub const DATAWRITER_CRYPTO_TOKENS: &str = "dds.sec.datawriter_crypto_tokens";
    /// Class id used for messages carrying data-reader crypto tokens.
    pub const DATAREADER_CRYPTO_TOKENS: &str = "dds.sec.datareader_crypto_tokens";
}
77
/// Identity of a single message: the GUID of its source plus a sequence number.
///
/// The [`Default`] value (all-zero GUID, sequence number `0`) is the "nil"
/// identity that [`MessageIdentity::is_nil`] recognises.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MessageIdentity {
    /// 16-byte GUID of the message source.
    pub source_guid: [u8; 16],
    /// Sequence number of the message.
    pub sequence_number: i64,
}

impl MessageIdentity {
    /// Returns `true` when both the GUID is all zeros and the sequence number
    /// is zero, i.e. when this value equals `MessageIdentity::default()`.
    #[must_use]
    pub fn is_nil(&self) -> bool {
        // The derived `Default` is exactly the nil identity, so a structural
        // comparison covers both fields at once.
        *self == Self::default()
    }
}
94
/// Generic participant message: addressing information, a class id and a list
/// of [`DataHolder`] payloads.
///
/// Fields are declared in the same order in which `to_cdr_le` writes them and
/// `from_cdr_le` reads them.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ParticipantGenericMessage {
    /// Identity of this message.
    pub message_identity: MessageIdentity,
    /// Identity of a related message. The unit tests use it to link a reply to
    /// the request it answers; the default is the nil identity.
    pub related_message_identity: MessageIdentity,
    /// 16-byte key of the destination participant.
    pub destination_participant_key: [u8; 16],
    /// 16-byte key of the destination endpoint.
    // NOTE(review): all-zero presumably means "no specific endpoint" — confirm
    // against the specification.
    pub destination_endpoint_key: [u8; 16],
    /// 16-byte key of the source endpoint.
    pub source_endpoint_key: [u8; 16],
    /// Class id of the message; see the [`class_id`] module for known values.
    pub message_class_id: String,
    /// Payload elements. They are serialized one after another, each starting
    /// on a 4-byte boundary, preceded by a single `u32` element count.
    pub message_data: Vec<DataHolder>,
}
116
/// Largest input, in bytes, that `ParticipantGenericMessage::from_cdr_le`
/// accepts; anything bigger is rejected before parsing starts.
const MAX_GENERIC_MESSAGE_BYTES: usize = 256 * 1024;

/// Largest `message_data` element count the decoder accepts.
const MAX_MESSAGE_DATA_LEN: u32 = 64;

/// Largest byte length (not counting the terminating NUL) the decoder accepts
/// for a string such as `message_class_id`.
const MAX_CLASS_ID_LEN: u32 = 256;
125
126impl ParticipantGenericMessage {
127 #[must_use]
131 pub fn to_cdr_le(&self) -> Vec<u8> {
132 let mut out = Vec::with_capacity(128);
133 encode_message_identity(&mut out, &self.message_identity, true);
134 encode_message_identity(&mut out, &self.related_message_identity, true);
135 out.extend_from_slice(&self.destination_participant_key);
136 out.extend_from_slice(&self.destination_endpoint_key);
137 out.extend_from_slice(&self.source_endpoint_key);
138 encode_string(&mut out, &self.message_class_id, true);
139 encode_u32(&mut out, self.message_data.len() as u32, true);
140 for dh in &self.message_data {
141 align(&mut out, 4);
148 out.extend_from_slice(&dh.to_cdr_le());
149 }
150 out
151 }
152
153 pub fn from_cdr_le(bytes: &[u8]) -> SecurityResult<Self> {
159 if bytes.len() > MAX_GENERIC_MESSAGE_BYTES {
160 return Err(SecurityError::new(
161 SecurityErrorKind::BadArgument,
162 "generic_message: payload exceeds DoS cap",
163 ));
164 }
165 let mut cur = Cursor::new(bytes, true);
166 let message_identity = decode_message_identity(&mut cur)?;
167 let related_message_identity = decode_message_identity(&mut cur)?;
168 let destination_participant_key = cur.read_array16()?;
169 let destination_endpoint_key = cur.read_array16()?;
170 let source_endpoint_key = cur.read_array16()?;
171 let message_class_id = decode_string(&mut cur)?;
172 if message_class_id.len() > MAX_CLASS_ID_LEN as usize {
173 return Err(SecurityError::new(
174 SecurityErrorKind::BadArgument,
175 "generic_message: message_class_id exceeds 256 bytes",
176 ));
177 }
178 let count = cur.read_u32()?;
179 if count > MAX_MESSAGE_DATA_LEN {
180 return Err(SecurityError::new(
181 SecurityErrorKind::BadArgument,
182 "generic_message: message_data sequence too long",
183 ));
184 }
185 let mut message_data = Vec::with_capacity(count as usize);
186 for _ in 0..count {
187 cur.align(4)?;
190 let (dh, consumed) = DataHolder::from_cdr_le_consumed(&cur.buf[cur.pos..])?;
191 cur.advance(consumed)?;
192 message_data.push(dh);
193 }
194 Ok(Self {
195 message_identity,
196 related_message_identity,
197 destination_participant_key,
198 destination_endpoint_key,
199 source_endpoint_key,
200 message_class_id,
201 message_data,
202 })
203 }
204}
205
/// Appends zero bytes to `buf` until its length is a multiple of `n`.
///
/// Does nothing when the length is already aligned. `n` must be non-zero
/// (a zero `n` panics on the remainder operation).
fn align(buf: &mut Vec<u8>, n: usize) {
    let overhang = buf.len() % n;
    if overhang != 0 {
        buf.resize(buf.len() + (n - overhang), 0);
    }
}
216
217fn encode_u32(buf: &mut Vec<u8>, v: u32, le: bool) {
218 align(buf, 4);
219 if le {
220 buf.extend_from_slice(&v.to_le_bytes());
221 } else {
222 buf.extend_from_slice(&v.to_be_bytes());
223 }
224}
225
226fn encode_i64(buf: &mut Vec<u8>, v: i64, le: bool) {
227 align(buf, 8);
228 if le {
229 buf.extend_from_slice(&v.to_le_bytes());
230 } else {
231 buf.extend_from_slice(&v.to_be_bytes());
232 }
233}
234
235fn encode_string(buf: &mut Vec<u8>, s: &str, le: bool) {
236 let bytes = s.as_bytes();
237 let len = (bytes.len() + 1) as u32;
238 encode_u32(buf, len, le);
239 buf.extend_from_slice(bytes);
240 buf.push(0);
241}
242
243fn encode_message_identity(buf: &mut Vec<u8>, mi: &MessageIdentity, le: bool) {
244 buf.extend_from_slice(&mi.source_guid);
245 encode_i64(buf, mi.sequence_number, le);
246}
247
248struct Cursor<'a> {
249 buf: &'a [u8],
250 pos: usize,
251 le: bool,
252}
253
254impl<'a> Cursor<'a> {
255 fn new(buf: &'a [u8], le: bool) -> Self {
256 Self { buf, pos: 0, le }
257 }
258
259 fn align(&mut self, n: usize) -> SecurityResult<()> {
260 let pad = (n - self.pos % n) % n;
261 self.advance(pad)
262 }
263
264 fn advance(&mut self, n: usize) -> SecurityResult<()> {
265 if self.pos.saturating_add(n) > self.buf.len() {
266 return Err(SecurityError::new(
267 SecurityErrorKind::BadArgument,
268 "generic_message: truncated",
269 ));
270 }
271 self.pos += n;
272 Ok(())
273 }
274
275 fn read_u32(&mut self) -> SecurityResult<u32> {
276 self.align(4)?;
277 let start = self.pos;
278 self.advance(4)?;
279 let mut a = [0u8; 4];
280 a.copy_from_slice(&self.buf[start..start + 4]);
281 Ok(if self.le {
282 u32::from_le_bytes(a)
283 } else {
284 u32::from_be_bytes(a)
285 })
286 }
287
288 fn read_i64(&mut self) -> SecurityResult<i64> {
289 self.align(8)?;
290 let start = self.pos;
291 self.advance(8)?;
292 let mut a = [0u8; 8];
293 a.copy_from_slice(&self.buf[start..start + 8]);
294 Ok(if self.le {
295 i64::from_le_bytes(a)
296 } else {
297 i64::from_be_bytes(a)
298 })
299 }
300
301 fn read_array16(&mut self) -> SecurityResult<[u8; 16]> {
302 let start = self.pos;
303 self.advance(16)?;
304 let mut a = [0u8; 16];
305 a.copy_from_slice(&self.buf[start..start + 16]);
306 Ok(a)
307 }
308
309 fn read_slice(&mut self, n: usize) -> SecurityResult<&'a [u8]> {
310 let start = self.pos;
311 self.advance(n)?;
312 Ok(&self.buf[start..start + n])
313 }
314}
315
316fn decode_message_identity(cur: &mut Cursor<'_>) -> SecurityResult<MessageIdentity> {
317 let source_guid = cur.read_array16()?;
318 let sequence_number = cur.read_i64()?;
319 Ok(MessageIdentity {
320 source_guid,
321 sequence_number,
322 })
323}
324
325fn decode_string(cur: &mut Cursor<'_>) -> SecurityResult<String> {
326 let len = cur.read_u32()? as usize;
327 if len == 0 {
328 return Err(SecurityError::new(
329 SecurityErrorKind::BadArgument,
330 "generic_message: zero-length string (no NUL)",
331 ));
332 }
333 if len > MAX_CLASS_ID_LEN as usize + 1 {
334 return Err(SecurityError::new(
335 SecurityErrorKind::BadArgument,
336 "generic_message: string > cap",
337 ));
338 }
339 let raw = cur.read_slice(len)?;
340 if raw[len - 1] != 0 {
341 return Err(SecurityError::new(
342 SecurityErrorKind::BadArgument,
343 "generic_message: missing terminating NUL",
344 ));
345 }
346 let s = core::str::from_utf8(&raw[..len - 1]).map_err(|_| {
347 SecurityError::new(SecurityErrorKind::BadArgument, "generic_message: non-utf8")
348 })?;
349 Ok(s.to_string())
350}
351
// Unit tests for the generic-message codec.
// `unwrap` is allowed here: a failed unwrap is simply a failed test.
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a message with every field populated and one bare `DataHolder`.
    fn sample_msg() -> ParticipantGenericMessage {
        ParticipantGenericMessage {
            message_identity: MessageIdentity {
                source_guid: [0xAA; 16],
                sequence_number: 42,
            },
            related_message_identity: MessageIdentity::default(),
            destination_participant_key: [0xBB; 16],
            destination_endpoint_key: [0; 16],
            source_endpoint_key: [0xCC; 16],
            message_class_id: class_id::AUTH_REQUEST.to_string(),
            message_data: vec![DataHolder::new("DDS:Auth:PKI-DH:1.2+AuthReq")],
        }
    }

    /// Encoding followed by decoding yields the original message.
    #[test]
    fn roundtrip_le() {
        let msg = sample_msg();
        let bytes = msg.to_cdr_le();
        let back = ParticipantGenericMessage::from_cdr_le(&bytes).unwrap();
        assert_eq!(msg, back);
    }

    /// Each `DataHolder` is written inline after the `u32` element count; it
    /// is not wrapped in an octet sequence with its own length prefix.
    #[test]
    fn message_data_dataholder_is_inline_not_length_prefixed() {
        let msg = sample_msg();
        let bytes = msg.to_cdr_le();
        let dh_inline = msg.message_data[0].to_cdr_le();
        assert!(
            bytes.ends_with(&dh_inline),
            "the DataHolder must stand INLINE at the end"
        );
        // The four bytes directly before the holder must be the element count.
        let pos = bytes.len() - dh_inline.len();
        let prefix = u32::from_le_bytes([
            bytes[pos - 4],
            bytes[pos - 3],
            bytes[pos - 2],
            bytes[pos - 1],
        ]);
        assert_eq!(
            prefix, 1,
            "before the DataHolder stands the sequence count (=1)"
        );
        assert_ne!(
            prefix as usize,
            dh_inline.len(),
            "NO octet-seq length prefix before the DataHolder"
        );
    }

    /// The default identity is nil; a non-zero GUID makes it non-nil.
    #[test]
    fn nil_message_identity() {
        let mi = MessageIdentity::default();
        assert!(mi.is_nil());
        let mi2 = MessageIdentity {
            source_guid: [0xAA; 16],
            sequence_number: 0,
        };
        assert!(!mi2.is_nil());
    }

    /// Pins the exact class-id strings so they cannot drift unnoticed.
    #[test]
    fn class_id_constants_match_spec() {
        assert_eq!(class_id::AUTH_REQUEST, "dds.sec.auth_request");
        assert_eq!(class_id::AUTH, "dds.sec.auth");
        assert_eq!(
            class_id::PARTICIPANT_CRYPTO_TOKENS,
            "dds.sec.participant_crypto_tokens"
        );
        assert_eq!(
            class_id::DATAWRITER_CRYPTO_TOKENS,
            "dds.sec.datawriter_crypto_tokens"
        );
        assert_eq!(
            class_id::DATAREADER_CRYPTO_TOKENS,
            "dds.sec.datareader_crypto_tokens"
        );
    }

    /// Pins the exact topic and type-name strings.
    #[test]
    fn topic_name_constants_match_spec() {
        assert_eq!(TOPIC_STATELESS_MESSAGE, "DCPSParticipantStatelessMessage");
        assert_eq!(
            TOPIC_VOLATILE_MESSAGE_SECURE,
            "DCPSParticipantVolatileMessageSecure"
        );
        assert_eq!(TYPE_NAME_GENERIC_MESSAGE, "ParticipantGenericMessage");
    }

    /// A message without any `message_data` survives a round trip.
    #[test]
    fn empty_message_data_roundtrip() {
        let msg = ParticipantGenericMessage {
            message_class_id: class_id::AUTH.to_string(),
            ..ParticipantGenericMessage::default()
        };
        let bytes = msg.to_cdr_le();
        let back = ParticipantGenericMessage::from_cdr_le(&bytes).unwrap();
        assert_eq!(msg, back);
        assert!(back.message_data.is_empty());
    }

    /// A token carrying a string property and a binary property is preserved
    /// when transported inside `message_data`.
    #[test]
    fn handshake_request_token_in_message_data() {
        let token = DataHolder::new("DDS:Auth:PKI-DH:1.2+AuthReq")
            .with_property("c.dsign_algo", "ECDSA-SHA256")
            .with_binary_property("c.id", vec![0x30, 0x82, 0x01, 0x23]);
        let msg = ParticipantGenericMessage {
            message_identity: MessageIdentity {
                source_guid: [0xAA; 16],
                sequence_number: 1,
            },
            destination_participant_key: [0xBB; 16],
            source_endpoint_key: [0xCC; 16],
            message_class_id: class_id::AUTH_REQUEST.to_string(),
            message_data: vec![token],
            ..ParticipantGenericMessage::default()
        };
        let bytes = msg.to_cdr_le();
        let back = ParticipantGenericMessage::from_cdr_le(&bytes).unwrap();
        assert_eq!(back.message_data.len(), 1);
        assert_eq!(back.message_data[0].class_id, "DDS:Auth:PKI-DH:1.2+AuthReq");
        assert_eq!(
            back.message_data[0].property("c.dsign_algo"),
            Some("ECDSA-SHA256")
        );
        assert_eq!(
            back.message_data[0].binary_property("c.id"),
            Some(&[0x30, 0x82, 0x01, 0x23][..])
        );
    }

    /// A reply that stores the request's identity in `related_message_identity`
    /// keeps that link across a round trip.
    #[test]
    fn related_message_identity_links_reply_to_request() {
        let request_id = MessageIdentity {
            source_guid: [0xAA; 16],
            sequence_number: 1,
        };
        let reply = ParticipantGenericMessage {
            message_identity: MessageIdentity {
                source_guid: [0xDD; 16],
                sequence_number: 1,
            },
            related_message_identity: request_id.clone(),
            destination_participant_key: [0xAA; 16],
            source_endpoint_key: [0xDD; 16],
            message_class_id: class_id::AUTH.to_string(),
            ..ParticipantGenericMessage::default()
        };
        let bytes = reply.to_cdr_le();
        let back = ParticipantGenericMessage::from_cdr_le(&bytes).unwrap();
        assert_eq!(back.related_message_identity, request_id);
    }

    /// Cutting a valid encoding in half must produce an error, not a panic.
    #[test]
    fn truncated_buffer_rejected() {
        let msg = sample_msg();
        let bytes = msg.to_cdr_le();
        let truncated = &bytes[..bytes.len() / 2];
        assert!(ParticipantGenericMessage::from_cdr_le(truncated).is_err());
    }

    /// A class id whose bytes are not valid UTF-8 is rejected.
    #[test]
    fn invalid_class_id_utf8_rejected() {
        let mut buf = Vec::new();
        // message_identity: 16-byte GUID + i64 sequence number.
        buf.extend_from_slice(&[0u8; 16]);
        buf.extend_from_slice(&0i64.to_le_bytes());
        // related_message_identity: same shape.
        buf.extend_from_slice(&[0u8; 16]);
        buf.extend_from_slice(&0i64.to_le_bytes());
        // Three 16-byte keys.
        buf.extend_from_slice(&[0u8; 48]);
        // String length 5 = four payload bytes + terminating NUL.
        buf.extend_from_slice(&5u32.to_le_bytes());
        // Payload that is not valid UTF-8, followed by the NUL.
        buf.extend_from_slice(&[0xFF, 0xFE, 0xFD, 0xFC, 0x00]);
        align(&mut buf, 4);
        // message_data count = 0.
        buf.extend_from_slice(&0u32.to_le_bytes());
        let err = ParticipantGenericMessage::from_cdr_le(&buf).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    /// Input one byte over the total-size cap is rejected.
    #[test]
    fn dos_cap_total_payload_rejected() {
        let big = vec![0u8; MAX_GENERIC_MESSAGE_BYTES + 1];
        let err = ParticipantGenericMessage::from_cdr_le(&big).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    /// An element count far above `MAX_MESSAGE_DATA_LEN` is rejected.
    #[test]
    fn message_data_cap_rejected() {
        let mut buf = Vec::new();
        // message_identity (16 + 8 bytes).
        buf.extend_from_slice(&[0u8; 24]);
        // related_message_identity (16 + 8 bytes).
        buf.extend_from_slice(&[0u8; 24]);
        // Three 16-byte keys.
        buf.extend_from_slice(&[0u8; 48]);
        // String length 1: an empty class id, just the terminating NUL.
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.push(0);
        align(&mut buf, 4);
        // Announced element count, well above the cap.
        buf.extend_from_slice(&1_000_000u32.to_le_bytes());
        let err = ParticipantGenericMessage::from_cdr_le(&buf).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }
}
575}