secure_serial/
receiver.rs1use embassy_futures::select::{Either, select};
7use embassy_sync::{blocking_mutex::raw::RawMutex, channel};
8use embedded_buffer_pool::{BufferGuard, BufferPool, MappedBufferGuard};
9use heapless::Vec;
10
11use crate::protocol::{
12 Ack, CHUNK_LEN_MAX, CHUNK_PAYLOAD_MAX, MAGIC, MAGIC_0, MAGIC_1, PACKET_ACK, PACKET_DATA,
13};
14use crate::transport::{CrcDevice, TransportRead, TransportWrite};
15
// Warning-level logging. With the `defmt` feature the arguments are passed
// straight to `defmt::warn!`; without it every invocation expands to nothing,
// so the arguments are never evaluated.
#[cfg(feature = "defmt")]
macro_rules! log_warn {
    ($($tokens:tt)*) => {{
        defmt::warn!($($tokens)*)
    }};
}

#[cfg(not(feature = "defmt"))]
macro_rules! log_warn {
    ($($tokens:tt)*) => {};
}
27
// Error-level logging. With the `defmt` feature the arguments are passed
// straight to `defmt::error!`; without it every invocation expands to nothing,
// so the arguments are never evaluated.
#[cfg(feature = "defmt")]
macro_rules! log_error {
    ($($tokens:tt)*) => {{
        defmt::error!($($tokens)*)
    }};
}

#[cfg(not(feature = "defmt"))]
macro_rules! log_error {
    ($($tokens:tt)*) => {};
}
39
40struct RxPacket<M: RawMutex + 'static, const N_BUF: usize> {
41 buffer: BufferGuard<M, [u8; N_BUF]>,
42 packet_id: u16,
43 packet_length: usize,
44 buffer_written: [u32; 4], buffer_written_count: usize,
46}
47
48pub async fn run_read<
53 M: RawMutex + 'static,
54 T: TransportRead,
55 const N_INFLIGHT: usize,
56 const N_POOL: usize,
57 const N_BUF: usize,
58>(
59 transport: &mut T,
60 crc_dev: &mut impl CrcDevice,
61 buffer_pool: &'static BufferPool<M, [u8; N_BUF], N_POOL>,
62 rx_queue: channel::Sender<'_, M, MappedBufferGuard<M, [u8]>, N_POOL>,
63 acks_to_send: channel::Sender<'_, M, Ack, N_INFLIGHT>,
64 acks_received: channel::Sender<'_, M, Ack, N_INFLIGHT>,
65) -> Result<(), T::Error> {
66 let mut chunk_buffer = [0; 2 * CHUNK_LEN_MAX + 4];
67 let mut chunk_buffer_count = 0;
68 let mut rx_packet: Option<RxPacket<M, N_BUF>> = None;
70 let mut last_successfully_received_packet: Option<u16> = None;
72 'outer: loop {
73 if chunk_buffer_count > (chunk_buffer.len() - CHUNK_LEN_MAX) {
75 log_error!(
76 "chunk buffer overflow - was not cleared in previous loop iteration. please report this as a bug."
77 );
78 chunk_buffer_count = 0;
79 }
80 chunk_buffer_count += transport
82 .read(&mut chunk_buffer[chunk_buffer_count..][..CHUNK_LEN_MAX])
83 .await?;
84
85 let mut buffer_start = 0;
86 loop {
87 let data_valid = &chunk_buffer[buffer_start..chunk_buffer_count];
89 let Some(index_start) = data_valid.iter().position(|&v| v == MAGIC_0) else {
90 chunk_buffer_count = 0;
92 continue 'outer;
93 };
94
95 buffer_start += index_start;
97 let data_valid = &chunk_buffer[buffer_start..chunk_buffer_count];
98 debug_assert!(data_valid[0] == MAGIC_0);
99
100 if data_valid.len() < (2 + 1 + 1 + 4) {
103 break;
105 }
106
107 if data_valid[1] != MAGIC_1 {
109 buffer_start += 1;
111 continue;
112 }
113
114 let chunk_type = data_valid[3];
116 match chunk_type {
117 PACKET_DATA | PACKET_ACK => (),
119 _ => {
120 buffer_start += 1;
122 continue;
123 }
124 }
125
126 let chunk_length = data_valid[2] as usize;
128 if chunk_length > CHUNK_LEN_MAX {
129 buffer_start += 1;
131 continue;
132 }
133
134 if data_valid.len() < (chunk_length + 4) {
136 break;
138 }
139
140 let crc_calc = crc_dev.crc(&data_valid[..chunk_length]).await;
142 let crc_read = u32::from_le_bytes(data_valid[chunk_length..][..4].try_into().unwrap());
143 if crc_calc != crc_read {
144 buffer_start += 1;
146 log_warn!("received chunk with invalid crc");
147 continue;
148 }
149
150 if chunk_length < 4 {
152 log_warn!("chunk length too short for header fields");
153 buffer_start += chunk_length + 4;
154 continue;
155 }
156
157 let buffer_chunk = &data_valid[4..chunk_length];
158 buffer_start += chunk_length + 4;
160
161 match chunk_type {
162 PACKET_DATA => {
163 const DATA_HEADER_BODY_LEN: usize = 2 + 4 + 4;
165 if buffer_chunk.len() < DATA_HEADER_BODY_LEN {
166 log_warn!("DATA chunk length too short for fixed header");
167 continue;
168 }
169
170 let packet_id = u16::from_le_bytes(buffer_chunk[0..2].try_into().unwrap());
171 let packet_length =
172 u32::from_le_bytes(buffer_chunk[2..6].try_into().unwrap()) as usize;
173 let chunk_offset =
174 u32::from_le_bytes(buffer_chunk[6..10].try_into().unwrap()) as usize;
175
176 let payload = &buffer_chunk[10..];
177
178 if packet_length > N_BUF {
179 log_warn!("received a chunk belonging to a packet that exceeds N_BUF");
180 continue;
181 }
182
183 if (chunk_offset * CHUNK_PAYLOAD_MAX + payload.len()) > packet_length {
184 log_warn!("received a chunk that exceeds its packet's length");
185 continue;
186 }
187
188 let payload_length_expected =
189 (packet_length - chunk_offset * CHUNK_PAYLOAD_MAX).min(CHUNK_PAYLOAD_MAX);
190 if payload.len() != payload_length_expected {
191 log_warn!(
192 "chunk payload length ({}) does not match chunk offset {} and packet length {}",
193 payload.len(),
194 chunk_offset,
195 packet_length
196 );
197 continue;
198 }
199
200 if let Some(packet_id_last) = last_successfully_received_packet
202 && packet_id_last == packet_id
203 {
204 acks_to_send
206 .try_send(Ack {
207 packet_id,
208 chunk_offset: chunk_offset as u32,
209 })
210 .ok();
211 continue;
212 }
213
214 if let Some(rxp) = rx_packet.as_ref()
216 && rxp.packet_id != packet_id
217 {
218 rx_packet = None;
219 }
220
221 if rx_packet.is_none() {
223 rx_packet = buffer_pool.try_take().map(|buf| RxPacket {
225 buffer: buf,
226 packet_id,
227 packet_length,
228 buffer_written: [0; _],
229 buffer_written_count: 0,
230 })
231 }
232
233 let Some(rxp) = rx_packet.as_mut() else {
235 log_warn!(
236 "could not allocate a buffer for new packet with id {} and length {}",
237 packet_id,
238 packet_length,
239 );
240 continue;
241 };
242
243 acks_to_send
245 .try_send(Ack {
246 packet_id,
247 chunk_offset: chunk_offset as u32,
248 })
249 .ok();
250
251 let rx_packet_buffer = &mut *rxp.buffer;
253 rx_packet_buffer[chunk_offset * CHUNK_PAYLOAD_MAX..][..payload.len()]
254 .copy_from_slice(payload);
255
256 let id_num = chunk_offset / 32;
258 let id_bit = chunk_offset % 32;
259 let buffer_written = &mut rxp.buffer_written[id_num];
260 if (*buffer_written & (1 << id_bit)) == 0 {
261 rxp.buffer_written_count += 1;
262 }
263 *buffer_written |= 1 << id_bit;
264
265 let num_chunks = rxp.packet_length.div_ceil(CHUNK_PAYLOAD_MAX);
267 if rxp.buffer_written_count == num_chunks {
268 let length = rxp.packet_length;
269 let rx_packet = rx_packet.take().unwrap();
270 rx_queue
271 .send(BufferGuard::map(rx_packet.buffer, |buf| &mut buf[..length]))
272 .await;
273 last_successfully_received_packet = Some(rx_packet.packet_id);
274 }
275 }
276 PACKET_ACK => {
277 let mut buf = buffer_chunk;
278 while buf.len() >= 6 {
279 let ack = Ack::from_buffer(buf[..6].try_into().unwrap());
280 acks_received.try_send(ack).ok();
281 buf = &buf[6..];
282 }
283 }
284 _t => {
285 log_warn!("received unknown packet type {:#02X}", _t);
286 continue;
287 }
288 }
289 }
290
291 if buffer_start == chunk_buffer_count {
293 chunk_buffer_count = 0;
295 } else if buffer_start != 0 {
296 chunk_buffer.copy_within(buffer_start..chunk_buffer_count, 0);
297 chunk_buffer_count -= buffer_start;
298 }
299 }
300}
301
302pub async fn run_write<M: RawMutex + 'static, T: TransportWrite, const N_INFLIGHT: usize>(
305 transport: &mut T,
306 tx_queue: &mut channel::Receiver<'_, M, BufferGuard<M, Vec<u8, CHUNK_LEN_MAX>>, N_INFLIGHT>,
307 ack_queue: &mut channel::Receiver<'_, M, Ack, N_INFLIGHT>,
308 crc_dev: &mut impl CrcDevice,
309) -> Result<(), T::Error> {
310 let mut ack_buf = Vec::<u8, CHUNK_LEN_MAX>::new();
311 loop {
312 match select(ack_queue.receive(), tx_queue.receive()).await {
313 Either::First(ack) => {
314 ack_buf.clear();
315 ack_buf.extend_from_slice(&MAGIC).ok();
317 let idx_len = ack_buf.len();
319 ack_buf.push(0).ok();
320 ack_buf.push(PACKET_ACK).ok();
322 ack_buf.extend_from_slice(&ack.to_buffer()).ok();
324 while (ack_buf.capacity() - ack_buf.len()) >= (6 + 4)
326 && let Ok(ack) = ack_queue.try_receive()
327 {
328 ack_buf.extend_from_slice(&ack.to_buffer()).ok();
329 }
330 ack_buf[idx_len] = ack_buf.len() as u8;
331
332 let crc = crc_dev.crc(&ack_buf).await;
333 ack_buf.extend_from_slice(&crc.to_le_bytes()).ok();
334 transport.write(&ack_buf).await?;
335 }
336 Either::Second(mut tx_buffer) => {
337 let crc = crc_dev.crc(&tx_buffer).await;
338 tx_buffer.extend_from_slice(&crc.to_le_bytes()).ok();
339 transport.write(&tx_buffer).await?;
340 }
341 }
342 }
343}