1use crate::byte_queue::ByteQueue;
41use core::convert::TryFrom;
42use crc::{Crc, Digest, CRC_32_ISO_HDLC};
43
/// Framed message layer on top of a [`ByteQueue`].
///
/// Writers append complete frames to `byte_queue`. Readers pull raw bytes from the
/// same `byte_queue` into `rx_buf`, search for `prefix`, validate the frame found
/// there and hand the payload out as a slice of `rx_buf`.
#[derive(Debug)]
pub struct MsgQueue<'a, const LEN: usize> {
    // Byte stream frames are written to and read from.
    byte_queue: ByteQueue,
    // Marker that starts every frame; used to resynchronise the reader.
    prefix: &'a [u8],
    // Reassembly buffer; `LEN` bounds the largest frame that can be received.
    rx_buf: [u8; LEN],
    // Number of valid bytes at the front of `rx_buf`.
    rx_buf_len: usize,
    // Set once a frame has been handed out; that frame is dropped from `rx_buf`
    // at the start of the next read.
    has_received_full_msg: bool,
}
56
57use core::fmt;
58
/// Errors reported by `MsgQueue` operations.
///
/// The enum carries no data, so it is `Copy` and can be compared with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MqError {
    /// Not enough free space in the byte queue to write the frame right now.
    MqFull,
    /// No complete, valid frame is currently available to read.
    MqEmpty,
    /// A header or frame CRC did not match the received bytes.
    MqCrcErr,
    /// The frame cannot fit in the byte queue (write side) or the receive buffer
    /// (read side), or its payload length does not fit the `u32` length field.
    MqMsgTooBig,
    /// The frame carries a protocol version other than the one this code speaks.
    MqWrongProtocolVersion,
}
67
68impl fmt::Display for MqError {
69 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
70 match self {
71 MqError::MqFull => write!(f, "Message queue is full"),
72 MqError::MqEmpty => write!(f, "Message queue is empty"),
73 MqError::MqCrcErr => write!(f, "CRC check failed"),
74 MqError::MqMsgTooBig => write!(f, "Message is too big"),
75 MqError::MqWrongProtocolVersion => {
76 write!(f, "Message protocol version is incompatible")
77 }
78 }
79 }
80}
81
82impl core::error::Error for MqError {}
83
/// Version byte written after the prefix of every frame; readers reject any other value.
const PROTOCOL_VERSION: u8 = 1;

// Field widths, listed in the order the fields appear on the wire.

/// Width in bytes of the protocol-version field.
const MSG_PROTOCOL_FIELD_SIZE: usize = size_of::<u8>();
/// Width in bytes of the little-endian payload-length field.
const MSG_SIZE_FIELD_SIZE: usize = size_of::<u32>();
/// Width in bytes of each little-endian CRC-32 field (header CRC and frame CRC).
const MSG_CRC_FIELD_SIZE: usize = size_of::<u32>();
89
/// CRC-32/ISO-HDLC engine used for both the header CRC and the trailing frame CRC.
const CRC32: Crc<u32> = Crc::<u32>::new(&CRC_32_ISO_HDLC);
91
92impl<'a, const LEN: usize> MsgQueue<'a, LEN> {
93 pub fn new(byte_queue: ByteQueue, prefix: &'a [u8], rx_buf: [u8; LEN]) -> Self {
97 Self {
98 byte_queue,
99 prefix,
100 rx_buf,
101 rx_buf_len: 0,
102 has_received_full_msg: false,
103 }
104 }
105
106 fn len_p(&self) -> usize {
112 self.prefix.len()
113 }
114
115 fn len_pv(&self) -> usize {
117 self.len_p() + MSG_PROTOCOL_FIELD_SIZE
118 }
119
120 fn len_pvl(&self) -> usize {
122 self.len_pv() + MSG_SIZE_FIELD_SIZE
123 }
124
125 fn len_pvlc(&self) -> usize {
127 self.len_pvl() + MSG_CRC_FIELD_SIZE
128 }
129
130 fn len_pvlcd(&self, msg_len: usize) -> usize {
132 self.len_pvlc() + msg_len
133 }
134
135 fn len_pvlcdc(&self, msg_len: usize) -> usize {
138 self.len_pvlcd(msg_len) + MSG_CRC_FIELD_SIZE
139 }
140
141 fn read_bytes(&mut self) {
143 let read_bytes_len = self
144 .byte_queue
145 .consume_at_most(&mut self.rx_buf[self.rx_buf_len..]);
146 self.rx_buf_len += read_bytes_len;
147 }
148
149 fn skip_in_rx_buf(&mut self, skip: usize) {
152 assert!(
153 skip <= self.rx_buf_len,
154 "skip rx_buffer value exceeds current rx_buffer length."
155 );
156
157 self.rx_buf.copy_within(skip..self.rx_buf_len, 0);
158 self.rx_buf_len -= skip;
159 }
160
161 fn rm_old_msg(&mut self) {
163 if self.has_received_full_msg {
164 let msg_len = self.try_extract_msg_len().unwrap(); if msg_len <= self.rx_buf_len {
167 self.skip_in_rx_buf(self.len_pvlcdc(msg_len));
169 }
170 self.has_received_full_msg = false;
171 }
172 }
173
174 fn invalidate_current_msg(&mut self) {
175 self.skip_in_rx_buf(1);
176 }
177
178 fn try_advance_to_prefix(&mut self) -> Result<(), MqError> {
181 let mut pos = None;
182
183 for (idx, window) in self.rx_buf[..self.rx_buf_len]
184 .windows(self.prefix.len())
185 .enumerate()
186 {
187 if self.prefix == window {
188 pos = Some(idx);
189 break;
190 }
191 }
192
193 if let Some(idx) = pos {
194 self.skip_in_rx_buf(idx);
195 return Ok(());
196 }
197
198 if self.rx_buf_len >= self.prefix.len() {
203 self.skip_in_rx_buf(self.rx_buf_len - self.prefix.len());
204 }
205
206 Err(MqError::MqEmpty)
207 }
208
209 fn try_extract_msg_len(&self) -> Result<usize, MqError> {
212 if self.rx_buf_len < self.len_pvl() {
213 return Err(MqError::MqEmpty);
214 }
215 let start = self.len_pv();
216 let end = start + MSG_SIZE_FIELD_SIZE;
217 let slice = &self.rx_buf[start..end];
218
219 let mut array = [0u8; MSG_SIZE_FIELD_SIZE];
220 array.copy_from_slice(slice);
221
222 Ok(u32::from_le_bytes(array) as usize)
223 }
224
225 fn verify_msg_packet_len(&mut self, msg_len: usize) -> Result<(), MqError> {
227 if self.rx_buf.len() < self.len_pvlcdc(msg_len) {
229 self.invalidate_current_msg();
230 return Err(MqError::MqMsgTooBig);
231 }
232 Ok(())
233 }
234
235 fn verify_protocol_version(&mut self) -> Result<(), MqError> {
237 if self.rx_buf_len < self.len_pv() {
238 return Err(MqError::MqEmpty);
239 }
240 if self.rx_buf[self.len_p()] != PROTOCOL_VERSION {
241 self.invalidate_current_msg();
242 return Err(MqError::MqWrongProtocolVersion);
243 }
244 Ok(())
245 }
246
247 fn verify_crc(&mut self, crc_start: usize, calculated_crc: u32) -> Result<(), MqError> {
249 if self.rx_buf_len < crc_start + MSG_CRC_FIELD_SIZE {
250 return Err(MqError::MqEmpty); }
252 let crc_end = crc_start + MSG_CRC_FIELD_SIZE;
253 let mut crc_array = [0u8; MSG_CRC_FIELD_SIZE];
254 crc_array.copy_from_slice(&self.rx_buf[crc_start..crc_end]);
255 let received_crc = u32::from_le_bytes(crc_array);
256
257 if received_crc != calculated_crc {
258 self.invalidate_current_msg();
259 return Err(MqError::MqCrcErr);
260 }
261 Ok(())
262 }
263
264 fn verify_full_msg(&mut self) -> Result<usize, MqError> {
267 self.verify_protocol_version()?;
268 let msg_len = self.try_extract_msg_len()?;
269 self.verify_crc(
270 self.len_pvl(),
271 CRC32.checksum(&self.rx_buf[..self.len_pvl()]),
272 )?;
273 self.verify_msg_packet_len(msg_len)?;
277 self.verify_crc(
280 self.len_pvlcd(msg_len),
281 CRC32.checksum(&self.rx_buf[..self.len_pvlcd(msg_len)]),
282 )?;
283
284 Ok(msg_len)
285 }
286
287 fn find_next_msg(&mut self) -> Result<(usize, usize), MqError> {
295 self.rm_old_msg();
296 self.read_bytes();
297 self.try_advance_to_prefix()?;
298 let msg_len = self.verify_full_msg()?;
299 self.has_received_full_msg = true;
300 Ok((self.len_pvlc(), self.len_pvlc() + msg_len))
301 }
302
303 pub fn read_or_fail(&mut self) -> Result<&[u8], MqError> {
307 let (start, end) = self.find_next_msg()?;
308 Ok(&self.rx_buf[start..end])
309 }
310
311 pub fn read_blocking(&mut self) -> Result<&[u8], MqError> {
314 loop {
315 match self.find_next_msg() {
316 Ok((start, end)) => return Ok(&self.rx_buf[start..end]),
317 Err(MqError::MqFull | MqError::MqEmpty) => continue,
318 Err(err) => return Err(err),
319 }
320 }
321 }
322
323 fn wacc(&mut self, digest: &mut Digest<u32>, data: &[u8]) {
329 self.byte_queue.write_or_fail(data).unwrap();
330 digest.update(data);
331 }
332
333 fn write_msg(&mut self, msg_data: &[u8]) -> Result<(), MqError> {
340 let mut header_crc = CRC32.digest();
341 self.wacc(&mut header_crc, self.prefix);
342 self.wacc(&mut header_crc, &PROTOCOL_VERSION.to_le_bytes());
343 let msg_len_u32 = u32::try_from(msg_data.len()).map_err(|_| MqError::MqMsgTooBig)?;
344 self.wacc(&mut header_crc, &msg_len_u32.to_le_bytes());
345 let mut total_crc = header_crc.clone();
348 let header_crc_bytes = header_crc.finalize().to_le_bytes();
349 self.wacc(&mut total_crc, &header_crc_bytes);
350 self.wacc(&mut total_crc, msg_data);
351 let total_crc_bytes = total_crc.finalize().to_le_bytes();
352 self.byte_queue.write_or_fail(&total_crc_bytes).unwrap();
353
354 Ok(())
355 }
356
357 pub fn write_or_fail(&mut self, msg_data: &[u8]) -> Result<(), MqError> {
363 if self.byte_queue.capacity() < self.len_pvlcdc(msg_data.len()) {
364 return Err(MqError::MqMsgTooBig);
365 }
366 if self.byte_queue.space() < self.len_pvlcdc(msg_data.len()) {
367 return Err(MqError::MqFull);
368 }
369 self.write_msg(msg_data)?;
370
371 Ok(())
372 }
373
374 pub fn write_blocking(&mut self, msg_data: &[u8]) -> Result<(), MqError> {
381 if self.byte_queue.capacity() < self.len_pvlcdc(msg_data.len()) {
382 return Err(MqError::MqMsgTooBig);
383 }
384 while self.byte_queue.space() < self.len_pvlcdc(msg_data.len()) {}
385 self.write_msg(msg_data)?;
386 Ok(())
387 }
388}
389
// Unit tests for `MsgQueue`.
//
// Each test backs a `ByteQueue` with a local `[u32; N]` array passed to the unsafe
// `ByteQueue::create` as a raw byte pointer plus a byte length (`len * 4`).
// Assumed contract (TODO confirm against `ByteQueue::create`): the array outlives the
// queue, which is why `bq_buf` is declared before `msg_queue` in every test.
//
// Several tests poke the backing array directly. The offsets they use assume the
// queued bytes start 8 bytes (two `u32` words) into the buffer, and that two words
// plus one byte of the buffer are unavailable for data. This is inferred from the
// arithmetic in `test_saturate_queue` and `test_incompatible_protocol_version` — TODO
// confirm against `ByteQueue`.
//
// NOTE(review): writing `bq_buf` directly while `ByteQueue` still holds a raw pointer
// derived from it may be flagged by Miri's aliasing model — confirm with `cargo miri test`.
#[cfg(test)]
mod tests {
    use crate::byte_queue::ByteQueue;
    use crate::msg_queue::{
        MqError, MsgQueue, MSG_CRC_FIELD_SIZE, MSG_PROTOCOL_FIELD_SIZE, MSG_SIZE_FIELD_SIZE,
    };

    // 16-byte frame marker shared by every test.
    const DEFAULT_PREFIX: &'static [u8] = b"DEFAULT_PREFIX: ";

    // `skip_in_rx_buf` must drop exactly `skip` leading bytes, for every skip from 0 up
    // to the whole buffer, and shift the remainder to the front of `rx_buf`.
    #[test]
    fn test_skip_in_rx_buf() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };

        let s = b"abcde";
        for skip in 0..=s.len() {
            // Reset the receive buffer to "abcde" before trying each skip amount.
            msg_queue.rx_buf[..s.len()].copy_from_slice(s);
            msg_queue.rx_buf_len = s.len();

            msg_queue.skip_in_rx_buf(skip);
            assert_eq!(&msg_queue.rx_buf[..msg_queue.rx_buf_len], &s[skip..]);
            assert_eq!(msg_queue.rx_buf_len, s.len() - skip);
        }
    }

    // A frame larger than the byte queue's capacity is rejected with `MqMsgTooBig`; a frame
    // whose length equals the capacity exactly is accepted.
    #[test]
    fn test_invalid_msg_size() {
        // 10 words = 40 bytes of backing storage.
        let mut bq_buf = [0u32; 10];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                // Nothing is read in this test, so the receive buffer can be tiny.
                [0u8; 7],
            )
        };

        // 4-byte payload: the frame is larger than the queue can ever hold.
        let data = b"abcd";
        let msg_size = DEFAULT_PREFIX.len()
            + MSG_PROTOCOL_FIELD_SIZE
            + MSG_SIZE_FIELD_SIZE
            + MSG_CRC_FIELD_SIZE
            + data.len()
            + MSG_CRC_FIELD_SIZE;
        assert!(msg_queue.byte_queue.capacity() < msg_size);
        assert_eq!(msg_queue.write_or_fail(data), Err(MqError::MqMsgTooBig));

        // 2-byte payload: the frame length equals the capacity exactly and must fit.
        let data = b"ab";
        let msg_size = DEFAULT_PREFIX.len()
            + MSG_PROTOCOL_FIELD_SIZE
            + MSG_SIZE_FIELD_SIZE
            + MSG_CRC_FIELD_SIZE
            + data.len()
            + MSG_CRC_FIELD_SIZE;
        assert!(msg_queue.byte_queue.capacity() == msg_size);
        assert_eq!(msg_queue.write_or_fail(&[1, 2]), Ok(()));
    }

    // Reading from a freshly created, empty queue reports `MqEmpty`.
    #[test]
    fn test_read_empty_queue() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };

        let result = msg_queue.read_or_fail();
        assert_eq!(result, Err(MqError::MqEmpty));
    }

    // A written message is read back unchanged.
    #[test]
    fn test_write_and_read_msg() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };

        let msg = b"Hello, World!";
        let result = msg_queue.write_or_fail(msg);
        assert!(result.is_ok());

        let read_msg = msg_queue.read_or_fail().unwrap();
        assert_eq!(read_msg, msg);
    }

    // Corrupting the header CRC, and separately the tail of the frame, must surface as
    // `MqCrcErr`. In both cases a valid frame written afterwards must still be delivered.
    #[test]
    fn test_crc_error() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        let msg = b"xxxxyyyy";

        msg_queue.write_or_fail(msg).unwrap();
        // Zero the storage from word `2 + len_pvl / 4` onward. With the data assumed to start
        // at word 2, this clears the header CRC and everything after it.
        bq_buf[2 + msg_queue.len_pvl() / 4..].fill(0);
        assert_eq!(msg_queue.read_or_fail(), Err(MqError::MqCrcErr));
        msg_queue.write_or_fail(msg).unwrap();
        assert_eq!(msg_queue.read_or_fail().unwrap(), msg);

        // Start over with a new queue on the same storage (presumably `create` resets it).
        msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        msg_queue.write_or_fail(msg).unwrap();
        // Zero from word `2 + len_pvlcd / 4` onward: under the same layout assumption, this
        // corrupts the end of the payload and the trailing frame CRC.
        bq_buf[2 + msg_queue.len_pvlcd(msg.len()) / 4..].fill(0);
        assert_eq!(msg_queue.read_or_fail(), Err(MqError::MqCrcErr));
        msg_queue.write_blocking(msg).unwrap();
        assert_eq!(msg_queue.read_or_fail().unwrap(), msg);
    }

    // Writing frames until the queue is full accounts for space exactly, and the next
    // write reports `MqFull`.
    #[test]
    fn test_saturate_queue() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };

        let data = b"abcd";
        let msg_size = DEFAULT_PREFIX.len()
            + MSG_PROTOCOL_FIELD_SIZE
            + MSG_SIZE_FIELD_SIZE
            + MSG_CRC_FIELD_SIZE
            + data.len()
            + MSG_CRC_FIELD_SIZE;
        // Usable bytes as assumed here: buffer size minus two bookkeeping words minus one
        // reserved byte. `repeat` is how many whole frames fit into that.
        let repeat = (bq_buf.len() * 4 - 2 * core::mem::size_of::<u32>() - 1) / msg_size;
        assert_eq!(repeat, 7);

        for _ in 0..repeat {
            let result = msg_queue.write_or_fail(data);
            assert_eq!(result, Ok(()));
        }
        assert_eq!(
            msg_queue.byte_queue.space(),
            (bq_buf.len() * 4 - 2 * core::mem::size_of::<u32>() - 1 - repeat * msg_size)
        );

        let result = msg_queue.write_or_fail(data);
        assert_eq!(result, Err(MqError::MqFull));
    }

    // Dropping the first byte of the only buffered frame breaks its prefix, so no frame
    // can be found afterwards.
    #[test]
    fn test_read_after_invalid_msg() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };

        let msg = b"valid msg";
        msg_queue.write_or_fail(msg).unwrap();

        msg_queue.read_bytes();
        msg_queue.invalidate_current_msg();

        let result = msg_queue.read_or_fail();
        assert_eq!(result, Err(MqError::MqEmpty));
    }

    // With two frames buffered, breaking the first one makes the reader resynchronise on
    // the second.
    #[test]
    fn test_read_write_after_invalid_msg() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };

        let msg = b"valid msg";
        msg_queue.write_or_fail(msg).unwrap();
        msg_queue.write_or_fail(msg).unwrap();

        msg_queue.read_bytes();
        msg_queue.invalidate_current_msg();

        let result = msg_queue.read_or_fail().unwrap();
        assert_eq!(result, msg);
    }

    // The blocking write/read pair round-trips a message when data is already available.
    #[test]
    fn test_blocking_read_msg() {
        let mut bq_buf = [0u32; 64];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };
        let msg = b"Blocking Msg";

        msg_queue.write_blocking(msg).unwrap();

        let read_msg = msg_queue.read_blocking().unwrap();
        assert_eq!(read_msg, msg);
    }

    // With a 64-byte receive buffer, 44..=63 garbage bytes followed by a frame mean the first
    // read only sees part of the frame (`MqEmpty`), and the next read completes it.
    #[test]
    fn test_read_part_of_next_msg() {
        let mut bq_buf = [0u32; 128];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64],
            )
        };

        let msg = b"valid msg";
        let garbage = [0xff; 128];

        for garbage_len in 64 - 20..64 {
            // Raw garbage goes straight into the byte queue, bypassing framing.
            msg_queue
                .byte_queue
                .write_or_fail(&garbage[..garbage_len])
                .unwrap();
            msg_queue.write_or_fail(msg).unwrap();
            assert_eq!(msg_queue.read_or_fail(), Err(MqError::MqEmpty));
            assert_eq!(msg_queue.read_or_fail().unwrap(), msg);
        }
    }

    // A frame whose version byte was altered is rejected with `MqWrongProtocolVersion`,
    // and a valid frame written afterwards is still delivered.
    #[test]
    fn test_incompatible_protocol_version() {
        let mut bq_buf = [0u32; 128];
        let mut msg_queue = unsafe {
            MsgQueue::new(
                ByteQueue::create(bq_buf.as_mut_ptr() as *mut u8, bq_buf.len() * 4),
                DEFAULT_PREFIX,
                [0u8; 64 * 4],
            )
        };

        let msg = b"xxxxyyyy";
        msg_queue.write_blocking(msg).unwrap();
        // View the backing storage as bytes and overwrite the first frame's version byte
        // (assumed at byte 8 + prefix length) with an unsupported value.
        let u8_slice: &mut [u8] =
            unsafe { std::slice::from_raw_parts_mut(bq_buf.as_mut_ptr() as *mut u8, 128 * 4) };
        u8_slice[8 + msg_queue.len_p()] = 2;
        assert_eq!(
            msg_queue.read_or_fail(),
            Err(MqError::MqWrongProtocolVersion)
        );
        msg_queue.write_blocking(msg).unwrap();
        assert_eq!(msg_queue.read_or_fail().unwrap(), msg);
    }
}