1extern crate alloc;
46
47use alloc::string::{String, ToString};
48use alloc::vec::Vec;
49
50use crate::error::{SecurityError, SecurityErrorKind, SecurityResult};
51use crate::token::DataHolder;
52
/// Topic name used to exchange stateless `ParticipantGenericMessage`s.
pub const TOPIC_STATELESS_MESSAGE: &'static str = "DCPSParticipantStatelessMessage";

/// Topic name used to exchange volatile, secured `ParticipantGenericMessage`s.
pub const TOPIC_VOLATILE_MESSAGE_SECURE: &'static str = "DCPSParticipantVolatileMessageSecure";

/// Registered type name shared by both generic-message topics above.
pub const TYPE_NAME_GENERIC_MESSAGE: &'static str = "ParticipantGenericMessage";
61
/// Values carried in `ParticipantGenericMessage::message_class_id`.
pub mod class_id {
    /// Class id for an authentication request message.
    pub const AUTH_REQUEST: &'static str = "dds.sec.auth_request";
    /// Class id for an authentication (handshake) message.
    pub const AUTH: &'static str = "dds.sec.auth";
    /// Class id for a message carrying participant crypto tokens.
    pub const PARTICIPANT_CRYPTO_TOKENS: &'static str = "dds.sec.participant_crypto_tokens";
    /// Class id for a message carrying DataWriter crypto tokens.
    pub const DATAWRITER_CRYPTO_TOKENS: &'static str = "dds.sec.datawriter_crypto_tokens";
    /// Class id for a message carrying DataReader crypto tokens.
    pub const DATAREADER_CRYPTO_TOKENS: &'static str = "dds.sec.datareader_crypto_tokens";
}
77
/// Identifies one generic message: the 16-byte GUID of its source plus a
/// sequence number.
///
/// On the wire it is the 16 GUID bytes followed by an 8-byte-aligned `i64`.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct MessageIdentity {
    /// 16-byte GUID of the message source.
    pub source_guid: [u8; 16],
    /// Sequence number of the message.
    pub sequence_number: i64,
}

impl MessageIdentity {
    /// Returns `true` when both the GUID and the sequence number are zero,
    /// i.e. the identity equals `MessageIdentity::default()`.
    #[must_use]
    pub fn is_nil(&self) -> bool {
        *self == Self::default()
    }
}
94
95#[derive(Debug, Clone, PartialEq, Eq, Default)]
97pub struct ParticipantGenericMessage {
98 pub message_identity: MessageIdentity,
100 pub related_message_identity: MessageIdentity,
103 pub destination_participant_key: [u8; 16],
106 pub destination_endpoint_key: [u8; 16],
108 pub source_endpoint_key: [u8; 16],
110 pub message_class_id: String,
112 pub message_data: Vec<DataHolder>,
115}
116
/// Largest encoded message `from_cdr_le` will look at (256 KiB). It is also
/// the cap on any single nested octet sequence.
const MAX_GENERIC_MESSAGE_BYTES: usize = 1 << 18;

/// Maximum number of `DataHolder` entries accepted in `message_data`.
const MAX_MESSAGE_DATA_LEN: u32 = 64;

/// Maximum `message_class_id` length in bytes, excluding the CDR NUL terminator.
const MAX_CLASS_ID_LEN: u32 = 256;
125
126impl ParticipantGenericMessage {
127 #[must_use]
131 pub fn to_cdr_le(&self) -> Vec<u8> {
132 let mut out = Vec::with_capacity(128);
133 encode_message_identity(&mut out, &self.message_identity, true);
134 encode_message_identity(&mut out, &self.related_message_identity, true);
135 out.extend_from_slice(&self.destination_participant_key);
136 out.extend_from_slice(&self.destination_endpoint_key);
137 out.extend_from_slice(&self.source_endpoint_key);
138 encode_string(&mut out, &self.message_class_id, true);
139 encode_u32(&mut out, self.message_data.len() as u32, true);
140 for dh in &self.message_data {
141 let dh_bytes = dh.to_cdr_le();
144 encode_octet_seq(&mut out, &dh_bytes, true);
145 }
146 out
147 }
148
149 pub fn from_cdr_le(bytes: &[u8]) -> SecurityResult<Self> {
155 if bytes.len() > MAX_GENERIC_MESSAGE_BYTES {
156 return Err(SecurityError::new(
157 SecurityErrorKind::BadArgument,
158 "generic_message: payload exceeds DoS cap",
159 ));
160 }
161 let mut cur = Cursor::new(bytes, true);
162 let message_identity = decode_message_identity(&mut cur)?;
163 let related_message_identity = decode_message_identity(&mut cur)?;
164 let destination_participant_key = cur.read_array16()?;
165 let destination_endpoint_key = cur.read_array16()?;
166 let source_endpoint_key = cur.read_array16()?;
167 let message_class_id = decode_string(&mut cur)?;
168 if message_class_id.len() > MAX_CLASS_ID_LEN as usize {
169 return Err(SecurityError::new(
170 SecurityErrorKind::BadArgument,
171 "generic_message: message_class_id exceeds 256 bytes",
172 ));
173 }
174 let count = cur.read_u32()?;
175 if count > MAX_MESSAGE_DATA_LEN {
176 return Err(SecurityError::new(
177 SecurityErrorKind::BadArgument,
178 "generic_message: message_data sequence too long",
179 ));
180 }
181 let mut message_data = Vec::with_capacity(count as usize);
182 for _ in 0..count {
183 let dh_bytes = decode_octet_seq(&mut cur)?;
184 let dh = DataHolder::from_cdr_le(&dh_bytes)?;
185 message_data.push(dh);
186 }
187 Ok(Self {
188 message_identity,
189 related_message_identity,
190 destination_participant_key,
191 destination_endpoint_key,
192 source_endpoint_key,
193 message_class_id,
194 message_data,
195 })
196 }
197}
198
/// Appends zero bytes until `buf.len()` is a multiple of `n`.
///
/// Offsets are measured from the start of `buf`. Panics if `n` is 0.
fn align(buf: &mut Vec<u8>, n: usize) {
    let overhang = buf.len() % n;
    if overhang != 0 {
        let padded_len = buf.len() + (n - overhang);
        buf.resize(padded_len, 0);
    }
}
209
210fn encode_u32(buf: &mut Vec<u8>, v: u32, le: bool) {
211 align(buf, 4);
212 if le {
213 buf.extend_from_slice(&v.to_le_bytes());
214 } else {
215 buf.extend_from_slice(&v.to_be_bytes());
216 }
217}
218
219fn encode_i64(buf: &mut Vec<u8>, v: i64, le: bool) {
220 align(buf, 8);
221 if le {
222 buf.extend_from_slice(&v.to_le_bytes());
223 } else {
224 buf.extend_from_slice(&v.to_be_bytes());
225 }
226}
227
228fn encode_string(buf: &mut Vec<u8>, s: &str, le: bool) {
229 let bytes = s.as_bytes();
230 let len = (bytes.len() + 1) as u32;
231 encode_u32(buf, len, le);
232 buf.extend_from_slice(bytes);
233 buf.push(0);
234}
235
236fn encode_octet_seq(buf: &mut Vec<u8>, v: &[u8], le: bool) {
237 encode_u32(buf, v.len() as u32, le);
238 buf.extend_from_slice(v);
239}
240
241fn encode_message_identity(buf: &mut Vec<u8>, mi: &MessageIdentity, le: bool) {
242 buf.extend_from_slice(&mi.source_guid);
243 encode_i64(buf, mi.sequence_number, le);
244}
245
/// Forward-only reader over a CDR buffer.
struct Cursor<'a> {
    /// Complete input; alignment is computed relative to its first byte.
    buf: &'a [u8],
    /// Offset of the next unread byte.
    pos: usize,
    /// `true` for little-endian integers, `false` for big-endian.
    le: bool,
}
251
252impl<'a> Cursor<'a> {
253 fn new(buf: &'a [u8], le: bool) -> Self {
254 Self { buf, pos: 0, le }
255 }
256
257 fn align(&mut self, n: usize) -> SecurityResult<()> {
258 let pad = (n - self.pos % n) % n;
259 self.advance(pad)
260 }
261
262 fn advance(&mut self, n: usize) -> SecurityResult<()> {
263 if self.pos.saturating_add(n) > self.buf.len() {
264 return Err(SecurityError::new(
265 SecurityErrorKind::BadArgument,
266 "generic_message: truncated",
267 ));
268 }
269 self.pos += n;
270 Ok(())
271 }
272
273 fn read_u32(&mut self) -> SecurityResult<u32> {
274 self.align(4)?;
275 let start = self.pos;
276 self.advance(4)?;
277 let mut a = [0u8; 4];
278 a.copy_from_slice(&self.buf[start..start + 4]);
279 Ok(if self.le {
280 u32::from_le_bytes(a)
281 } else {
282 u32::from_be_bytes(a)
283 })
284 }
285
286 fn read_i64(&mut self) -> SecurityResult<i64> {
287 self.align(8)?;
288 let start = self.pos;
289 self.advance(8)?;
290 let mut a = [0u8; 8];
291 a.copy_from_slice(&self.buf[start..start + 8]);
292 Ok(if self.le {
293 i64::from_le_bytes(a)
294 } else {
295 i64::from_be_bytes(a)
296 })
297 }
298
299 fn read_array16(&mut self) -> SecurityResult<[u8; 16]> {
300 let start = self.pos;
301 self.advance(16)?;
302 let mut a = [0u8; 16];
303 a.copy_from_slice(&self.buf[start..start + 16]);
304 Ok(a)
305 }
306
307 fn read_slice(&mut self, n: usize) -> SecurityResult<&'a [u8]> {
308 let start = self.pos;
309 self.advance(n)?;
310 Ok(&self.buf[start..start + n])
311 }
312}
313
314fn decode_message_identity(cur: &mut Cursor<'_>) -> SecurityResult<MessageIdentity> {
315 let source_guid = cur.read_array16()?;
316 let sequence_number = cur.read_i64()?;
317 Ok(MessageIdentity {
318 source_guid,
319 sequence_number,
320 })
321}
322
323fn decode_string(cur: &mut Cursor<'_>) -> SecurityResult<String> {
324 let len = cur.read_u32()? as usize;
325 if len == 0 {
326 return Err(SecurityError::new(
327 SecurityErrorKind::BadArgument,
328 "generic_message: zero-length string (no NUL)",
329 ));
330 }
331 if len > MAX_CLASS_ID_LEN as usize + 1 {
332 return Err(SecurityError::new(
333 SecurityErrorKind::BadArgument,
334 "generic_message: string > cap",
335 ));
336 }
337 let raw = cur.read_slice(len)?;
338 if raw[len - 1] != 0 {
339 return Err(SecurityError::new(
340 SecurityErrorKind::BadArgument,
341 "generic_message: missing terminating NUL",
342 ));
343 }
344 let s = core::str::from_utf8(&raw[..len - 1]).map_err(|_| {
345 SecurityError::new(SecurityErrorKind::BadArgument, "generic_message: non-utf8")
346 })?;
347 Ok(s.to_string())
348}
349
350fn decode_octet_seq(cur: &mut Cursor<'_>) -> SecurityResult<Vec<u8>> {
351 let len = cur.read_u32()? as usize;
352 if len > MAX_GENERIC_MESSAGE_BYTES {
353 return Err(SecurityError::new(
354 SecurityErrorKind::BadArgument,
355 "generic_message: octet_seq > cap",
356 ));
357 }
358 Ok(cur.read_slice(len)?.to_vec())
359}
360
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used)]
mod tests {
    use super::*;

    fn sample_msg() -> ParticipantGenericMessage {
        ParticipantGenericMessage {
            message_identity: MessageIdentity {
                source_guid: [0xAA; 16],
                sequence_number: 42,
            },
            related_message_identity: MessageIdentity::default(),
            destination_participant_key: [0xBB; 16],
            destination_endpoint_key: [0; 16],
            source_endpoint_key: [0xCC; 16],
            message_class_id: class_id::AUTH_REQUEST.to_string(),
            message_data: vec![DataHolder::new("DDS:Auth:PKI-DH:1.2+AuthReq")],
        }
    }

    /// Returns the fixed 96-byte wire prefix with every field zeroed: two
    /// `MessageIdentity` values (16-byte GUID + 8-byte sequence number each)
    /// followed by the three 16-byte keys. `message_class_id` comes next.
    fn zeroed_header() -> Vec<u8> {
        vec![0u8; 24 + 24 + 48]
    }

    #[test]
    fn roundtrip_le() {
        let msg = sample_msg();
        let bytes = msg.to_cdr_le();
        let back = ParticipantGenericMessage::from_cdr_le(&bytes).unwrap();
        assert_eq!(msg, back);
    }

    #[test]
    fn nil_message_identity() {
        let mi = MessageIdentity::default();
        assert!(mi.is_nil());
        let mi2 = MessageIdentity {
            source_guid: [0xAA; 16],
            sequence_number: 0,
        };
        assert!(!mi2.is_nil());
        let mi3 = MessageIdentity {
            source_guid: [0; 16],
            sequence_number: 7,
        };
        assert!(!mi3.is_nil());
    }

    #[test]
    fn class_id_constants_match_spec() {
        assert_eq!(class_id::AUTH_REQUEST, "dds.sec.auth_request");
        assert_eq!(class_id::AUTH, "dds.sec.auth");
        assert_eq!(
            class_id::PARTICIPANT_CRYPTO_TOKENS,
            "dds.sec.participant_crypto_tokens"
        );
        assert_eq!(
            class_id::DATAWRITER_CRYPTO_TOKENS,
            "dds.sec.datawriter_crypto_tokens"
        );
        assert_eq!(
            class_id::DATAREADER_CRYPTO_TOKENS,
            "dds.sec.datareader_crypto_tokens"
        );
    }

    #[test]
    fn topic_name_constants_match_spec() {
        assert_eq!(TOPIC_STATELESS_MESSAGE, "DCPSParticipantStatelessMessage");
        assert_eq!(
            TOPIC_VOLATILE_MESSAGE_SECURE,
            "DCPSParticipantVolatileMessageSecure"
        );
        assert_eq!(TYPE_NAME_GENERIC_MESSAGE, "ParticipantGenericMessage");
    }

    #[test]
    fn empty_message_data_roundtrip() {
        let msg = ParticipantGenericMessage {
            message_class_id: class_id::AUTH.to_string(),
            ..ParticipantGenericMessage::default()
        };
        let bytes = msg.to_cdr_le();
        let back = ParticipantGenericMessage::from_cdr_le(&bytes).unwrap();
        assert_eq!(msg, back);
        assert!(back.message_data.is_empty());
    }

    #[test]
    fn handshake_request_token_in_message_data() {
        let token = DataHolder::new("DDS:Auth:PKI-DH:1.2+AuthReq")
            .with_property("c.dsign_algo", "ECDSA-SHA256")
            .with_binary_property("c.id", vec![0x30, 0x82, 0x01, 0x23]);
        let msg = ParticipantGenericMessage {
            message_identity: MessageIdentity {
                source_guid: [0xAA; 16],
                sequence_number: 1,
            },
            destination_participant_key: [0xBB; 16],
            source_endpoint_key: [0xCC; 16],
            message_class_id: class_id::AUTH_REQUEST.to_string(),
            message_data: vec![token],
            ..ParticipantGenericMessage::default()
        };
        let bytes = msg.to_cdr_le();
        let back = ParticipantGenericMessage::from_cdr_le(&bytes).unwrap();
        assert_eq!(back.message_data.len(), 1);
        assert_eq!(back.message_data[0].class_id, "DDS:Auth:PKI-DH:1.2+AuthReq");
        assert_eq!(
            back.message_data[0].property("c.dsign_algo"),
            Some("ECDSA-SHA256")
        );
        assert_eq!(
            back.message_data[0].binary_property("c.id"),
            Some(&[0x30, 0x82, 0x01, 0x23][..])
        );
    }

    #[test]
    fn related_message_identity_links_reply_to_request() {
        let request_id = MessageIdentity {
            source_guid: [0xAA; 16],
            sequence_number: 1,
        };
        let reply = ParticipantGenericMessage {
            message_identity: MessageIdentity {
                source_guid: [0xDD; 16],
                sequence_number: 1,
            },
            related_message_identity: request_id.clone(),
            destination_participant_key: [0xAA; 16],
            source_endpoint_key: [0xDD; 16],
            message_class_id: class_id::AUTH.to_string(),
            ..ParticipantGenericMessage::default()
        };
        let bytes = reply.to_cdr_le();
        let back = ParticipantGenericMessage::from_cdr_le(&bytes).unwrap();
        assert_eq!(back.related_message_identity, request_id);
    }

    #[test]
    fn truncated_buffer_rejected() {
        let msg = sample_msg();
        let bytes = msg.to_cdr_le();
        let truncated = &bytes[..bytes.len() / 2];
        assert!(ParticipantGenericMessage::from_cdr_le(truncated).is_err());
    }

    #[test]
    fn invalid_class_id_utf8_rejected() {
        let mut buf = Vec::new();
        // message_identity: GUID + sequence number
        buf.extend_from_slice(&[0u8; 16]);
        buf.extend_from_slice(&0i64.to_le_bytes());
        // related_message_identity: GUID + sequence number
        buf.extend_from_slice(&[0u8; 16]);
        buf.extend_from_slice(&0i64.to_le_bytes());
        // the three 16-byte keys
        buf.extend_from_slice(&[0u8; 48]);
        // message_class_id: length 5 (4 invalid UTF-8 bytes + NUL)
        buf.extend_from_slice(&5u32.to_le_bytes());
        buf.extend_from_slice(&[0xFF, 0xFE, 0xFD, 0xFC, 0x00]);
        align(&mut buf, 4);
        // empty message_data
        buf.extend_from_slice(&0u32.to_le_bytes());
        let err = ParticipantGenericMessage::from_cdr_le(&buf).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn dos_cap_total_payload_rejected() {
        let big = vec![0u8; MAX_GENERIC_MESSAGE_BYTES + 1];
        let err = ParticipantGenericMessage::from_cdr_le(&big).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn message_data_cap_rejected() {
        let mut buf = zeroed_header();
        // message_class_id = "" (length 1: only the NUL terminator)
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.push(0);
        align(&mut buf, 4);
        // absurd message_data element count
        buf.extend_from_slice(&1_000_000u32.to_le_bytes());
        let err = ParticipantGenericMessage::from_cdr_le(&buf).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn zero_length_class_id_rejected() {
        let mut buf = zeroed_header();
        // A CDR string length of 0 cannot even hold the NUL terminator.
        buf.extend_from_slice(&0u32.to_le_bytes());
        buf.extend_from_slice(&0u32.to_le_bytes());
        let err = ParticipantGenericMessage::from_cdr_le(&buf).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn class_id_without_terminating_nul_rejected() {
        let mut buf = zeroed_header();
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.extend_from_slice(b"abc");
        align(&mut buf, 4);
        buf.extend_from_slice(&0u32.to_le_bytes());
        let err = ParticipantGenericMessage::from_cdr_le(&buf).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn class_id_over_cap_rejected() {
        let mut buf = zeroed_header();
        // One byte more than MAX_CLASS_ID_LEN characters plus the NUL allows.
        let declared = MAX_CLASS_ID_LEN + 2;
        buf.extend_from_slice(&declared.to_le_bytes());
        let content_end = buf.len() + declared as usize - 1;
        buf.resize(content_end, b'a');
        buf.push(0);
        align(&mut buf, 4);
        buf.extend_from_slice(&0u32.to_le_bytes());
        let err = ParticipantGenericMessage::from_cdr_le(&buf).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }

    #[test]
    fn class_id_at_cap_roundtrips() {
        let msg = ParticipantGenericMessage {
            message_class_id: "a".repeat(MAX_CLASS_ID_LEN as usize),
            ..ParticipantGenericMessage::default()
        };
        let back = ParticipantGenericMessage::from_cdr_le(&msg.to_cdr_le()).unwrap();
        assert_eq!(back, msg);
    }

    #[test]
    fn message_data_count_boundary() {
        let with_holders = |n: u32| ParticipantGenericMessage {
            message_class_id: class_id::AUTH.to_string(),
            message_data: (0..n)
                .map(|_| DataHolder::new("DDS:Auth:PKI-DH:1.2+AuthReq"))
                .collect(),
            ..ParticipantGenericMessage::default()
        };

        let at_cap = with_holders(MAX_MESSAGE_DATA_LEN);
        let back = ParticipantGenericMessage::from_cdr_le(&at_cap.to_cdr_le()).unwrap();
        assert_eq!(back.message_data.len(), MAX_MESSAGE_DATA_LEN as usize);

        let over_cap = with_holders(MAX_MESSAGE_DATA_LEN + 1);
        let err = ParticipantGenericMessage::from_cdr_le(&over_cap.to_cdr_le()).unwrap_err();
        assert_eq!(err.kind, SecurityErrorKind::BadArgument);
    }
}
551}