rusty_dtls/
lib.rs

1#![no_std]
2
3use core::{borrow::BorrowMut, marker::PhantomData, mem, net::SocketAddr, ops::Range};
4
5use crypto::TrafficSecret;
6use handshake::{
7    handle_handshake_message_client, handle_handshake_message_server, ClientState,
8    CryptoInformation, HandshakeContext, HandshakeInformation, ServerState,
9};
10use log::{debug, info, trace};
11use parsing::{
12    encode_alert, encode_hello_retry, parse_alert, parse_client_hello_first_pass,
13    parse_client_hello_second_pass, ClientHelloResult, EncodeAck, EncodeHandshakeMessage,
14    HandshakeType, HelloRetryCookie, ParseHandshakeMessage,
15};
16
17pub use crypto::{HashFunction, Psk};
18
19use parsing_utility::ParseBuffer;
20use record_parsing::{
21    parse_plaintext_record, parse_record, EncodeCiphertextRecord, EncodePlaintextRecord,
22    RecordContentType,
23};
24
25mod fmt;
26
27#[cfg(feature = "async")]
28mod asynchronous;
29#[cfg(feature = "async")]
30pub use asynchronous::{DtlsStackAsync, Event};
31
32mod sync;
33pub use sync::DtlsStack;
34
35mod buffer_record_queue;
36mod crypto;
37mod handshake;
38mod parsing;
39mod parsing_utility;
40mod record_parsing;
41
42type Epoch = u64;
43type EpochShort = u8;
44
45type TimeStampMs = u64;
46
47type HandshakeSeqNum = u16;
48
49type RecordSeqNum = u64;
50type RecordSeqNumShort = u8;
51
52type Connections<'a> = [Option<DtlsConnection<'a>>];
53type RecordQueue<'a> = buffer_record_queue::BufferMessageQueue<'a>;
54
/// Alert descriptions known to this implementation.
///
/// The discriminants are the on-the-wire description codes. `Unknown` is not
/// given an explicit value and therefore takes the next discriminant (111);
/// unrecognised codes are parsed into it.
#[derive(Debug, Clone, Copy)]
#[repr(u8)]
pub enum AlertDescription {
    CloseNotify = 0,
    UnexpectedMessage = 10,
    IllegalParameter = 47,
    DecodeError = 50,
    DecryptionError = 51,
    MissingExtension = 109,
    UnsupportedExtension = 110,
    /// Any description code not listed above.
    Unknown,
}

impl From<u8> for AlertDescription {
    /// Maps a wire description code to its variant; unlisted codes become `Unknown`.
    fn from(value: u8) -> Self {
        match value {
            0 => AlertDescription::CloseNotify,
            10 => AlertDescription::UnexpectedMessage,
            47 => AlertDescription::IllegalParameter,
            50 => AlertDescription::DecodeError,
            51 => AlertDescription::DecryptionError,
            109 => AlertDescription::MissingExtension,
            110 => AlertDescription::UnsupportedExtension,
            _ => AlertDescription::Unknown,
        }
    }
}

impl AlertDescription {
    /// Returns the (legacy) alert level to encode alongside this description.
    ///
    /// In (D)TLS 1.3 the level is implied by the description: `close_notify` is a
    /// closure alert and is conventionally sent at warning level, while every
    /// error alert must be sent as fatal (RFC 8446, Section 6). Unknown
    /// descriptions are treated as errors.
    pub fn alert_level(&self) -> AlertLevel {
        match self {
            AlertDescription::CloseNotify => AlertLevel::Warning,
            AlertDescription::UnexpectedMessage
            | AlertDescription::IllegalParameter
            | AlertDescription::DecodeError
            | AlertDescription::DecryptionError
            | AlertDescription::MissingExtension
            | AlertDescription::UnsupportedExtension
            | AlertDescription::Unknown => AlertLevel::Fatal,
        }
    }
}

/// Alert level field of an alert message.
#[derive(Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum AlertLevel {
    Warning = 1,
    Fatal = 2,
}

impl From<u8> for AlertLevel {
    /// Maps `1` to `Warning`; every other value is treated as `Fatal`.
    fn from(value: u8) -> Self {
        match value {
            1 => AlertLevel::Warning,
            _ => AlertLevel::Fatal,
        }
    }
}
113
114#[derive(Debug)]
115pub enum DtlsError {
116    MaximumConnectionsReached,
117    UnknownConnection,
118    MaximumRetransmissionsReached,
119    HandshakeAlreadyRunning,
120    OutOfMemory,
121    /// Indicates a bug in the implementation
122    IllegalInnerState,
123    IoError,
124    RngError,
125    ParseError,
126    CryptoError,
127    NoMatchingEpoch,
128    RejectedSequenceNumber,
129    Alert(AlertDescription),
130    MultipleRecordsPerPacketNotSupported,
131}
132
/// Opaque handle of a connection entry managed by the DTLS stack.
///
/// Formatting with `{:?}` prints only the wrapped entry index.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct ConnectionId(usize);

impl core::fmt::Debug for ConnectionId {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let ConnectionId(index) = self;
        core::fmt::Debug::fmt(index, f)
    }
}
141
142struct DtlsConnection<'a> {
143    epochs: heapless::Vec<EpochState, 4>,
144    current_epoch: Epoch,
145    pub addr: SocketAddr,
146    handshake_finished: bool,
147    p: PhantomData<&'a ()>,
148}
149
/// Tells the caller when the stack's `poll` should be called next.
#[derive(Debug, PartialEq, Eq)]
pub enum DtlsPoll {
    /// Call `poll` again when a new message arrives or after this many
    /// milliseconds, whichever comes first.
    WaitTimeoutMs(u32),
    /// Call `poll` again when a new message arrives; no timer is pending.
    Wait,
    /// A handshake finished. Call `poll` again immediately after resetting
    /// the handshake slot.
    FinishedHandshake,
}

impl DtlsPoll {
    /// Combines two poll results into the more urgent one.
    ///
    /// `FinishedHandshake` wins over everything; otherwise the shorter timeout
    /// wins, and `Wait` results only when neither side carries a timeout.
    pub fn merge(self, other: Self) -> Self {
        match (self, other) {
            (Self::FinishedHandshake, _) | (_, Self::FinishedHandshake) => {
                Self::FinishedHandshake
            }
            (Self::Wait, Self::Wait) => Self::Wait,
            (Self::WaitTimeoutMs(ms), Self::Wait) | (Self::Wait, Self::WaitTimeoutMs(ms)) => {
                Self::WaitTimeoutMs(ms)
            }
            (Self::WaitTimeoutMs(a), Self::WaitTimeoutMs(b)) => {
                Self::WaitTimeoutMs(if a <= b { a } else { b })
            }
        }
    }
}
178
179enum DeferredAction<'a> {
180    None,
181    Send(&'a [u8]),
182    AppData(ConnectionId, Range<usize>),
183    Unhandled,
184}
185
186fn try_pass_packet_to_connection<'a>(
187    staging_buffer: &'a mut [u8],
188    connections: &mut [Option<DtlsConnection>],
189    addr: &SocketAddr,
190    packet_len: usize,
191) -> Result<DeferredAction<'a>, DtlsError> {
192    for i in 0..connections.len() {
193        let connection = match connections[i].as_mut() {
194            Some(c) if &c.addr == addr && c.handshake_finished => c,
195            _ => continue,
196        };
197        let mut packet_buffer = ParseBuffer::init(&mut staging_buffer[..packet_len]);
198        let res = parse_record(&mut packet_buffer, &mut connection.epochs);
199        let action = match res {
200            Ok(RecordContentType::ApplicationData) => DeferredAction::AppData(
201                ConnectionId(i),
202                packet_buffer.offset()..packet_buffer.capacity(),
203            ),
204            Ok(RecordContentType::DtlsHandshake) => {
205                trace!("Received handshake message on existing connection");
206                match ParseHandshakeMessage::retrieve_content_type(&mut packet_buffer) {
207                    // The first ack might have gone lost
208                    Ok(
209                        HandshakeType::ServerHello
210                        | HandshakeType::EncryptedExtension
211                        | HandshakeType::Finished,
212                    ) if connection.current_epoch < 6 => {
213                        debug!("Found retransmitted handshake message. Resending ack.");
214                        DeferredAction::Send(stage_ack(
215                            staging_buffer,
216                            &mut connection.epochs,
217                            &3,
218                            &2,
219                        )?)
220                    }
221                    _ => close_connection(ConnectionId(i), staging_buffer, connections),
222                }
223            }
224            Ok(RecordContentType::Ack) => DeferredAction::None,
225            Ok(_) => close_connection(ConnectionId(i), staging_buffer, connections),
226            Err(err) => {
227                trace!("Received broken record: {:?}", err);
228                if let DtlsError::Alert(alert) = err {
229                    DeferredAction::Send(stage_alert(
230                        staging_buffer,
231                        &mut connection.epochs,
232                        &connection.current_epoch,
233                        alert,
234                    )?)
235                } else {
236                    DeferredAction::None
237                }
238            }
239        };
240        return Ok(action);
241    }
242    Ok(DeferredAction::Unhandled)
243}
244
245fn close_connection<'a>(
246    connection_id: ConnectionId,
247    staging_buffer: &'a mut [u8],
248    connections: &mut [Option<DtlsConnection>],
249) -> DeferredAction<'a> {
250    debug_assert!(connection_id.0 < connections.len());
251    let mut action = DeferredAction::None;
252    if connection_id.0 < connections.len() {
253        if let Some(c) = connections[connection_id.0].as_mut() {
254            if !c.handshake_finished {
255                return action;
256            }
257            if let Ok(buf) = stage_alert(
258                staging_buffer,
259                &mut c.epochs,
260                &c.current_epoch,
261                AlertDescription::CloseNotify,
262            ) {
263                action = DeferredAction::Send(buf);
264            }
265        }
266        connections[connection_id.0] = None;
267    }
268    action
269}
270
271fn try_pass_packet_to_handshake<'a>(
272    staging_buffer: &'a mut [u8],
273    connections: &mut [Option<DtlsConnection>],
274    handshakes: &mut [HandshakeSlot],
275    addr: &SocketAddr,
276    packet_len: usize,
277) -> Result<DeferredAction<'a>, DtlsError> {
278    for handshake in handshakes {
279        if let HandshakeSlotState::Running {
280            state,
281            handshake: ctx,
282        } = &mut handshake.state
283        {
284            let connection = ctx.connection(connections);
285            if &connection.addr != addr {
286                continue;
287            }
288            let Some((content_type, mut packet)) =
289                try_unpack_record(&mut staging_buffer[..packet_len], &mut connection.epochs)?
290            else {
291                return Ok(DeferredAction::None);
292            };
293            let mut new_state = *state;
294            if content_type == RecordContentType::Alert {
295                let (level, desc) = parse_alert(&mut packet)?;
296                info!("Received alert: {:?}, {:?}", level, desc);
297                handshake.close(connections);
298                continue;
299            }
300            // Implicit way to ack the client finished
301            else if content_type == RecordContentType::ApplicationData
302                && matches!(state, HandshakeState::Client(ClientState::WaitServerAck))
303            {
304                debug!("[Client] Acked client finished through app data");
305                let start = packet.offset();
306                let end = packet.capacity();
307                let id = ctx.conn_id;
308                handshake.finish_handshake(connection);
309                return Ok(DeferredAction::AppData(ConnectionId(id), start..end));
310            } else {
311                let res = match &mut new_state {
312                    HandshakeState::Client(state) => handle_handshake_message_client(
313                        state,
314                        ctx,
315                        &mut handshake.rt_queue,
316                        connection,
317                        content_type,
318                        packet,
319                    ),
320                    HandshakeState::Server(state) => handle_handshake_message_server(
321                        state,
322                        ctx,
323                        &mut handshake.rt_queue,
324                        connection,
325                        content_type,
326                        packet,
327                    ),
328                };
329                if let Err(DtlsError::Alert(alert)) = res {
330                    let buf = stage_alert(
331                        staging_buffer,
332                        &mut connection.epochs,
333                        &connection.current_epoch,
334                        alert,
335                    )?;
336                    return Ok(DeferredAction::Send(buf));
337                }
338                res?;
339                *state = new_state;
340                // Send ack for client finish
341                if matches!(
342                    state,
343                    HandshakeState::Server(ServerState::FinishedHandshake)
344                ) {
345                    debug!("[Server] Send ACK for client finish");
346                    return Ok(DeferredAction::Send(stage_ack(
347                        staging_buffer,
348                        &mut connection.epochs,
349                        &3,
350                        &2,
351                    )?));
352                }
353            }
354            return Ok(DeferredAction::None);
355        }
356    }
357    Ok(DeferredAction::Unhandled)
358}
359
360fn create_handshake_connection<'a, 'b>(
361    connections: &'a mut [Option<DtlsConnection<'b>>],
362    addr: &SocketAddr,
363) -> Result<(usize, &'a mut DtlsConnection<'b>), DtlsError> {
364    let slot = find_empty_connection_slot(connections);
365    if let Some(slot) = slot {
366        connections[slot] = Some(DtlsConnection {
367            epochs: heapless::Vec::new(),
368            current_epoch: 0,
369            addr: *addr,
370            handshake_finished: false,
371            p: PhantomData,
372        });
373        let _ = connections[slot]
374            .as_mut()
375            .unwrap()
376            .epochs
377            .push(EpochState::empty());
378        Ok((slot, connections[slot].as_mut().unwrap()))
379    } else {
380        Err(DtlsError::MaximumConnectionsReached)
381    }
382}
383
384fn open_connection(
385    connections: &mut Connections,
386    slot: &mut HandshakeSlot,
387    addr: &SocketAddr,
388) -> bool {
389    let Ok((conn_id, _)) = create_handshake_connection(connections, addr) else {
390        return false;
391    };
392
393    let HandshakeSlotState::Empty = slot.state else {
394        return false;
395    };
396    slot.state = HandshakeSlotState::Running {
397        state: HandshakeState::Client(ClientState::default()),
398        handshake: HandshakeContext {
399            recv_handshake_seq_num: 0,
400            send_handshake_seq_num: 0,
401            conn_id,
402            info: HandshakeInformation {
403                available_psks: slot.psks,
404                selected_psk: None,
405                crypto: CryptoInformation::new(),
406                selected_cipher_suite: None,
407                received_hello_retry_request: false,
408            },
409        },
410    };
411
412    true
413}
414
415fn find_empty_connection_slot(connections: &mut [Option<DtlsConnection>]) -> Option<usize> {
416    for (i, c) in connections.iter().enumerate() {
417        if c.is_none() {
418            return Some(i);
419        }
420    }
421    None
422}
423
424fn try_open_new_handshake<'a>(
425    staging_buffer: &'a mut [u8],
426    require_cookie: bool,
427    cookie_key: &[u8],
428    handshakes: &mut [HandshakeSlot],
429    connections: &mut [Option<DtlsConnection>],
430    addr: &SocketAddr,
431    packet_len: usize,
432) -> Result<Option<&'a [u8]>, DtlsError> {
433    let mut packet_buffer = ParseBuffer::init(&mut staging_buffer[..packet_len]);
434    let mut epoch_states = [EpochState::empty()];
435    let res = parse_plaintext_record(&mut packet_buffer, &mut epoch_states);
436    let Ok(RecordContentType::DtlsHandshake) = res else {
437        return Ok(None);
438    };
439    let mut send_buf = None;
440    for handshake_slot in handshakes {
441        if !matches!(handshake_slot.state, HandshakeSlotState::Empty) {
442            continue;
443        }
444        let Ok((conn_id, conn)) = create_handshake_connection(connections, addr) else {
445            return Ok(None);
446        };
447        conn.epochs[0] = epoch_states.into_iter().next().unwrap();
448
449        handshake_slot.fill(conn_id);
450        let HandshakeSlotState::Running {
451            state: _,
452            handshake: ctx,
453        } = &mut handshake_slot.state
454        else {
455            unreachable!()
456        };
457        let Ok((mut client_hello, HandshakeType::ClientHello, client_hello_seq_num @ (0 | 1))) =
458            ParseHandshakeMessage::new(packet_buffer)
459        else {
460            break;
461        };
462        let client_hello_start = client_hello.payload_buffer().offset();
463        match parse_client_hello_first_pass(
464            client_hello.payload_buffer(),
465            require_cookie,
466            cookie_key,
467            addr,
468            ctx,
469            &mut handshake_slot.rt_queue,
470        ) {
471            Ok(ClientHelloResult::MissingCookie) => {
472                debug!("[Server] Didn't find valid cookie. Sending hello_retry");
473                client_hello.add_to_transcript_hash(&mut ctx.info.crypto);
474                send_buf = Some(stage_hello_retry_message(
475                    staging_buffer,
476                    cookie_key,
477                    addr,
478                    &mut ctx.info,
479                )?);
480            }
481            Ok(ClientHelloResult::Ok) => {
482                if require_cookie {
483                    debug!("[Server] Found valid cookie opening handshake");
484                }
485                parse_client_hello_second_pass(
486                    client_hello.payload_buffer(),
487                    &mut ctx.info,
488                    client_hello_start,
489                )?;
490                client_hello.add_to_transcript_hash(&mut ctx.info.crypto);
491                conn.epochs[0].send_record_seq_num = client_hello_seq_num as u64;
492                ctx.send_handshake_seq_num = client_hello_seq_num as u8;
493                ctx.recv_handshake_seq_num = client_hello_seq_num as u8 + 1;
494                break;
495            }
496            Err(err) => {
497                debug!("[Server] Error parsing client_hello: {err:?}");
498            }
499        }
500        handshake_slot.close(connections);
501        connections[conn_id] = None;
502        break;
503    }
504    Ok(send_buf)
505}
506
507fn stage_hello_retry_message<'a>(
508    staging_buffer: &'a mut [u8],
509    cookie_key: &[u8],
510    addr: &SocketAddr,
511    info: &mut HandshakeInformation,
512) -> Result<&'a [u8], DtlsError> {
513    let mut buffer = ParseBuffer::init(staging_buffer.borrow_mut());
514    let mut record = EncodePlaintextRecord::new(&mut buffer, RecordContentType::DtlsHandshake, 0)?;
515    let mut handshake =
516        EncodeHandshakeMessage::new(record.payload_buffer(), HandshakeType::ServerHello, 0)?;
517    encode_hello_retry(
518        handshake.payload_buffer(),
519        &[],
520        info.selected_cipher_suite
521            .ok_or(DtlsError::IllegalInnerState)?,
522        HelloRetryCookie::calculate(info.crypto.psk_hash_mut()?, cookie_key, addr),
523    )?;
524    handshake.finish(&mut info.crypto);
525    record.finish();
526    let offset = buffer.offset();
527    Ok(&buffer.release_buffer()[..offset])
528}
529
530fn stage_ack<'a>(
531    staging_buffer: &'a mut [u8],
532    epoch_states: &mut [EpochState],
533    epoch: &u64,
534    ack_epoch: &u64,
535) -> Result<&'a [u8], DtlsError> {
536    let mut buffer = ParseBuffer::init(staging_buffer);
537    let send_epoch_index = *epoch as usize & 3;
538    let ack_epoch_index = *ack_epoch as usize & 3;
539    let max_entries = (buffer.capacity() as u64 - 2) / 16;
540    let mut record =
541        EncodeCiphertextRecord::new(&mut buffer, &epoch_states[send_epoch_index], epoch)?;
542    let mut ack = EncodeAck::new(record.payload_buffer())?;
543    let w = &epoch_states[ack_epoch_index].sliding_window;
544    let r = &epoch_states[ack_epoch_index].receive_record_seq_num;
545    let mut index = 1;
546    for i in 0..64.min(max_entries) {
547        if w & index > 0 {
548            let s = r - i;
549            ack.add_entry(ack_epoch, &s)?;
550        }
551        index <<= 1;
552    }
553    ack.finish();
554    record.finish(&mut epoch_states[send_epoch_index], RecordContentType::Ack)?;
555    let offset = buffer.offset();
556    Ok(&buffer.release_buffer()[..offset])
557}
558
559fn stage_alert<'a>(
560    staging_buffer: &'a mut [u8],
561    epoch_states: &mut [EpochState],
562    epoch: &u64,
563    alert: AlertDescription,
564) -> Result<&'a [u8], DtlsError> {
565    info!("Sending alert: {:?}", alert);
566    let epoch_index = *epoch as usize & 3;
567    let mut buffer = ParseBuffer::init(staging_buffer);
568    if epoch < &2 {
569        let mut record = EncodePlaintextRecord::new(
570            &mut buffer,
571            RecordContentType::Alert,
572            epoch_states[epoch_index].send_record_seq_num,
573        )?;
574        encode_alert(record.payload_buffer(), alert, alert.alert_level())?;
575        record.finish();
576    } else {
577        let mut record =
578            EncodeCiphertextRecord::new(&mut buffer, &epoch_states[epoch_index], epoch)?;
579        encode_alert(record.payload_buffer(), alert, alert.alert_level())?;
580        record.finish(&mut epoch_states[epoch_index], RecordContentType::Alert)?;
581    }
582    let offset = buffer.offset();
583    Ok(&buffer.release_buffer()[..offset])
584}
585
586pub struct HandshakeSlot<'a> {
587    rt_queue: RecordQueue<'a>,
588    psks: &'a [Psk<'a>],
589    state: HandshakeSlotState<'a>,
590}
591
592#[derive(Default)]
593pub enum HandshakeSlotState<'a> {
594    Running {
595        state: HandshakeState,
596        handshake: HandshakeContext<'a>,
597    },
598    #[default]
599    Empty,
600    Finished(ConnectionId),
601}
602
603#[derive(Clone, Copy)]
604pub enum HandshakeState {
605    Client(ClientState),
606    Server(ServerState),
607}
608
609impl<'a> HandshakeSlot<'a> {
610    pub fn new(available_psks: &'a [Psk<'a>], buffer: &'a mut [u8]) -> Self {
611        HandshakeSlot {
612            rt_queue: RecordQueue::new(buffer),
613            psks: available_psks,
614            state: HandshakeSlotState::Empty,
615        }
616    }
617
618    fn fill(&mut self, conn_id: usize) {
619        if let HandshakeSlotState::Empty = self.state {
620            self.state = HandshakeSlotState::Running {
621                state: HandshakeState::Server(ServerState::default()),
622                handshake: HandshakeContext {
623                    recv_handshake_seq_num: 0,
624                    send_handshake_seq_num: 0,
625                    conn_id,
626                    info: HandshakeInformation {
627                        received_hello_retry_request: false,
628                        available_psks: self.psks,
629                        selected_psk: None,
630                        crypto: CryptoInformation::new(),
631                        selected_cipher_suite: None,
632                    },
633                },
634            }
635        }
636    }
637
638    pub fn try_take_connection_id(&mut self) -> Option<ConnectionId> {
639        if let HandshakeSlotState::Finished(cid) = self.state {
640            self.state = HandshakeSlotState::Empty;
641            Some(cid)
642        } else {
643            None
644        }
645    }
646
647    fn finish_handshake(&mut self, conn: &mut DtlsConnection) {
648        if let HandshakeSlotState::Running {
649            state: _,
650            handshake: ctx,
651        } = mem::take(&mut self.state)
652        {
653            conn.handshake_finished = true;
654            self.rt_queue.reset();
655            let id = ctx.conn_id;
656            self.state = HandshakeSlotState::Finished(ConnectionId(id));
657        }
658    }
659
660    fn close(&mut self, connections: &mut [Option<DtlsConnection>]) {
661        debug!("Closing handshake prematurely");
662        match mem::take(&mut self.state) {
663            HandshakeSlotState::Running {
664                state: _,
665                handshake: c,
666            } => {
667                connections[c.conn_id] = None;
668            }
669            HandshakeSlotState::Empty | HandshakeSlotState::Finished(_) => {}
670        }
671        self.rt_queue.reset();
672    }
673}
674
675fn try_unpack_record<'a>(
676    packet: &'a mut [u8],
677    viable_epochs: &mut [EpochState],
678) -> Result<Option<(RecordContentType, ParseBuffer<'a>)>, DtlsError> {
679    let mut packet_buffer = ParseBuffer::init(packet);
680    let res = parse_record(&mut packet_buffer, viable_epochs);
681
682    match res {
683        Err(DtlsError::NoMatchingEpoch) => {
684            trace!("Rejected record because no cipher state was present for its epoch");
685            Ok(None)
686        }
687        Err(DtlsError::RejectedSequenceNumber) => {
688            trace!("Rejected record because it was already received");
689            Ok(None)
690        }
691        Err(DtlsError::ParseError | DtlsError::CryptoError) => {
692            trace!("Rejected record because it was broken");
693            Ok(None)
694        }
695        Err(err) => Err(err),
696        Ok(content_type) => Ok(Some((content_type, packet_buffer))),
697    }
698}
699
700struct EpochState {
701    send_record_seq_num: RecordSeqNum,
702    receive_record_seq_num: RecordSeqNum,
703    read_traffic_secret: TrafficSecret,
704    write_traffic_secret: TrafficSecret,
705    sliding_window: u64,
706}
707
708impl EpochState {
709    pub const fn new(
710        read_traffic_secret: TrafficSecret,
711        write_traffic_secret: TrafficSecret,
712    ) -> Self {
713        Self {
714            send_record_seq_num: 0,
715            receive_record_seq_num: 0,
716            read_traffic_secret,
717            write_traffic_secret,
718            sliding_window: 0,
719        }
720    }
721
722    pub const fn empty() -> Self {
723        Self::new(TrafficSecret::None, TrafficSecret::None)
724    }
725
726    pub(crate) fn check_seq_num(&self, seq_num: &u64) -> Result<(), DtlsError> {
727        const WINDOW_MAX_SHIFT_BITS: u64 = 64 - 1;
728        let highest_seq_num = self.receive_record_seq_num;
729
730        if highest_seq_num > *seq_num {
731            let diff = highest_seq_num - seq_num;
732            if diff > WINDOW_MAX_SHIFT_BITS {
733                return Err(DtlsError::RejectedSequenceNumber);
734            }
735            let window_index = 1u64 << diff;
736            if self.sliding_window & window_index > 0 {
737                // Record already present
738                return Err(DtlsError::RejectedSequenceNumber);
739            }
740        } else {
741            let shift = seq_num - highest_seq_num;
742            if shift == 0 && self.sliding_window & 1 == 1 {
743                // We already received this record
744                return Err(DtlsError::RejectedSequenceNumber);
745            }
746        }
747        Ok(())
748    }
749
750    pub(crate) fn mark_received(&mut self, seq_num: &u64) {
751        let highest_seq_num = &self.receive_record_seq_num;
752        if highest_seq_num > seq_num {
753            let diff = highest_seq_num - seq_num;
754            let window_index = 1u64 << diff;
755            debug_assert!(self.sliding_window & window_index == 0);
756            self.sliding_window |= window_index;
757        } else {
758            let shift = seq_num - highest_seq_num;
759            if shift >= 64 {
760                self.sliding_window = 0;
761            } else {
762                self.sliding_window <<= shift;
763            }
764            self.receive_record_seq_num = *seq_num;
765            debug_assert!(self.sliding_window & 1 == 0);
766            self.sliding_window |= 1;
767        }
768    }
769}
770
#[cfg(test)]
mod tests {

    use crate::{crypto::TrafficSecret, DtlsError, EpochState};

    #[test]
    fn reject_double_receive() {
        let mut state = EpochState::new(TrafficSecret::None, TrafficSecret::None);
        state.check_seq_num(&2).unwrap();
        state.mark_received(&2);
        assert!(matches!(
            state.check_seq_num(&2),
            Err(DtlsError::RejectedSequenceNumber)
        ));
    }

    #[test]
    fn reject_too_old_receive() {
        let mut state = EpochState::new(TrafficSecret::None, TrafficSecret::None);
        state.mark_received(&64);
        assert!(matches!(
            state.check_seq_num(&0),
            Err(DtlsError::RejectedSequenceNumber)
        ));
    }

    #[test]
    fn correctly_check_after_shift() {
        let mut state = EpochState::new(TrafficSecret::None, TrafficSecret::None);
        state.mark_received(&20);
        state.mark_received(&64);
        assert!(matches!(
            state.check_seq_num(&20),
            Err(DtlsError::RejectedSequenceNumber)
        ));
    }

    #[test]
    fn accept_then_reject_sequence_number_zero() {
        let mut state = EpochState::new(TrafficSecret::None, TrafficSecret::None);
        state.check_seq_num(&0).unwrap();
        state.mark_received(&0);
        assert!(matches!(
            state.check_seq_num(&0),
            Err(DtlsError::RejectedSequenceNumber)
        ));
    }

    #[test]
    fn accept_reordered_record_within_window() {
        let mut state = EpochState::new(TrafficSecret::None, TrafficSecret::None);
        state.mark_received(&10);
        state.check_seq_num(&5).unwrap();
        state.mark_received(&5);
        assert!(matches!(
            state.check_seq_num(&5),
            Err(DtlsError::RejectedSequenceNumber)
        ));
        // Unseen records around it are still acceptable.
        state.check_seq_num(&4).unwrap();
        state.check_seq_num(&11).unwrap();
    }

    #[test]
    fn window_boundary() {
        let mut state = EpochState::new(TrafficSecret::None, TrafficSecret::None);
        state.mark_received(&63);
        // 63 records behind the highest is still inside the window.
        state.check_seq_num(&0).unwrap();
        state.mark_received(&64);
        // 64 records behind is outside the window.
        assert!(matches!(
            state.check_seq_num(&0),
            Err(DtlsError::RejectedSequenceNumber)
        ));
        state.check_seq_num(&1).unwrap();
        assert!(matches!(
            state.check_seq_num(&63),
            Err(DtlsError::RejectedSequenceNumber)
        ));
    }
}