// srt_protocol/protocol/pending_connection/listen.rs

1use std::{convert::TryInto, net::SocketAddr, time::Instant};
2
3use crate::{packet::*, protocol::handshake::Handshake, settings::*};
4
5use super::{
6    cookie::gen_cookie, hsv5::gen_access_control_response, hsv5::GenHsv5Result,
7    AccessControlRequest, AccessControlResponse, ConnectError, Connection, ConnectionReject,
8    ConnectionResult,
9};
10
11use ConnectionResult::*;
12use ListenState::*;
13
14#[derive(Debug)]
15pub struct Listen {
16    init_settings: ConnInitSettings,
17    state: ListenState,
18    enable_access_control: bool,
19}
20
21#[derive(Clone, Debug)]
22pub struct ConclusionWaitState {
23    from: SocketAddr,
24    cookie: i32,
25    induction_response: Packet,
26    induction_time: Instant,
27}
28
29#[derive(Clone, Debug)]
30#[allow(clippy::large_enum_variant)]
31enum ListenState {
32    InductionWait,
33    ConclusionWait(ConclusionWaitState),
34    AccessControlRequested(
35        ConclusionWaitState,
36        TimeStamp,
37        HandshakeControlInfo,
38        HsV5Info,
39    ),
40}
41
42impl Listen {
43    pub fn new(init_settings: ConnInitSettings, enable_access_control: bool) -> Listen {
44        Listen {
45            state: InductionWait,
46            init_settings,
47            enable_access_control,
48        }
49    }
50
51    pub fn settings(&self) -> &ConnInitSettings {
52        &self.init_settings
53    }
54
55    pub fn handle_packet(&mut self, now: Instant, packet: ReceivePacketResult) -> ConnectionResult {
56        use ReceivePacketError::*;
57        match packet {
58            Ok((packet, from)) => match packet {
59                Packet::Control(control) => self.handle_control_packets(now, from, control),
60                Packet::Data(data) => NotHandled(ConnectError::ControlExpected(data)),
61            },
62            Err(Io(error)) => Failure(error),
63            Err(Parse(e)) => NotHandled(ConnectError::ParseFailed(e)),
64        }
65    }
66
67    pub fn handle_access_control_response(
68        &mut self,
69        now: Instant,
70        response: AccessControlResponse,
71    ) -> ConnectionResult {
72        match self.state.clone() {
73            // TODO: something other than ExpectedHsReq
74            InductionWait | ConclusionWait(_) => NotHandled(ConnectError::ExpectedHsReq),
75            AccessControlRequested(state, timestamp, shake, info) => {
76                use AccessControlResponse::*;
77                match response {
78                    Accepted(key_settings) => {
79                        self.accept_connection(now, &state, timestamp, shake, info, key_settings)
80                    }
81                    Rejected(rr) => self.make_rejection(
82                        &shake,
83                        state.from,
84                        timestamp,
85                        ConnectionReject::Rejecting(rr),
86                    ),
87                    Dropped => self.make_rejection(
88                        &shake,
89                        state.from,
90                        timestamp,
91                        ConnectionReject::Rejecting(RejectReason::Core(CoreRejectReason::Peer)),
92                    ),
93                }
94            }
95        }
96    }
97
98    pub fn handle_timer(&self, _now: Instant) -> ConnectionResult {
99        NoAction
100    }
101
102    fn handle_control_packets(
103        &mut self,
104        now: Instant,
105        from: SocketAddr,
106        control: ControlPacket,
107    ) -> ConnectionResult {
108        match (self.state.clone(), control.control_type) {
109            (InductionWait, ControlTypes::Handshake(shake)) => {
110                self.wait_for_induction(from, control.timestamp, shake, now)
111            }
112            (ConclusionWait(state), ControlTypes::Handshake(shake)) => self.wait_for_conclusion(
113                now,
114                from,
115                control.dest_sockid,
116                control.timestamp,
117                state,
118                shake,
119            ),
120            (AccessControlRequested(_, _, _, _), _) => {
121                NotHandled(ConnectError::ExpectedAccessControlResponse)
122            }
123            (InductionWait, control_type) | (ConclusionWait(_), control_type) => {
124                NotHandled(ConnectError::HandshakeExpected(control_type))
125            }
126        }
127    }
128
129    fn wait_for_induction(
130        &mut self,
131        from: SocketAddr,
132        timestamp: TimeStamp,
133        shake: HandshakeControlInfo,
134        now: Instant,
135    ) -> ConnectionResult {
136        match shake.shake_type {
137            ShakeType::Induction => {
138                // https://tools.ietf.org/html/draft-gg-udt-03#page-9
139                // When the server first receives the connection request from a client,
140                // it generates a cookie value according to the client address and a
141                // secret key and sends it back to the client. The client must then send
142                // back the same cookie to the server.
143
144                // generate the cookie, which is just a hash of the address + time
145                let cookie = gen_cookie(&from);
146
147                // we expect HSv5, so upgrade it
148                // construct a packet to send back
149                let induction_response = Packet::Control(ControlPacket {
150                    timestamp,
151                    dest_sockid: shake.socket_id,
152                    control_type: ControlTypes::Handshake(HandshakeControlInfo {
153                        syn_cookie: cookie,
154                        socket_id: self.init_settings.local_sockid,
155                        info: HandshakeVsInfo::V5(HsV5Info::default()),
156                        ..shake
157                    }),
158                });
159
160                // save induction message for potential later retransmit
161                let save_induction_response = induction_response.clone();
162                self.state = ConclusionWait(ConclusionWaitState {
163                    from,
164                    cookie,
165                    induction_response: save_induction_response,
166                    induction_time: now,
167                });
168                SendPacket((induction_response, from))
169            }
170            _ => NotHandled(ConnectError::InductionExpected(shake)),
171        }
172    }
173
174    fn wait_for_conclusion(
175        &mut self,
176        now: Instant,
177        from: SocketAddr,
178        local_socket_id: SocketId,
179        timestamp: TimeStamp,
180        state: ConclusionWaitState,
181        shake: HandshakeControlInfo,
182    ) -> ConnectionResult {
183        // https://tools.ietf.org/html/draft-gg-udt-03#page-10
184        // The server, when receiving a handshake packet and the correct cookie,
185        // compares the packet size and maximum window size with its own values
186        // and set its own values as the smaller ones. The result values are
187        // also sent back to the client by a response handshake packet, together
188        // with the server's version and initial sequence number. The server is
189        // ready for sending/receiving data right after this step is finished.
190        // However, it must send back response packet as long as it receives any
191        // further handshakes from the same client.
192
193        const VERSION_5: u32 = 5;
194
195        match (shake.shake_type, shake.info.version(), shake.syn_cookie) {
196            (ShakeType::Induction, _, _) => SendPacket((state.induction_response, from)),
197            // first induction received, wait for response (with cookie)
198            (ShakeType::Conclusion, VERSION_5, syn_cookie) if syn_cookie == state.cookie => {
199                let incoming = match &shake.info {
200                    HandshakeVsInfo::V5(hs) => hs,
201                    _ => {
202                        let r = ConnectionReject::Rejecting(
203                            // TODO: this error is technically reserved for access control handlers, as the ref impl supports hsv4+5, while we only support 5
204                            ServerRejectReason::Version.into(),
205                        );
206                        return self.make_rejection(&shake, from, timestamp, r);
207                    }
208                }
209                .clone();
210
211                if self.enable_access_control {
212                    self.request_access(from, local_socket_id, timestamp, state, shake, incoming)
213                } else {
214                    let key_settings = self.settings().key_settings.clone();
215                    self.accept_connection(now, &state, timestamp, shake, incoming, key_settings)
216                }
217            }
218            (ShakeType::Conclusion, VERSION_5, syn_cookie) => NotHandled(
219                ConnectError::InvalidHandshakeCookie(state.cookie, syn_cookie),
220            ),
221            (ShakeType::Conclusion, version, _) => {
222                NotHandled(ConnectError::UnsupportedProtocolVersion(version))
223            }
224            (_, _, _) => NotHandled(ConnectError::ConclusionExpected(shake)),
225        }
226    }
227
228    fn request_access(
229        &mut self,
230        remote: SocketAddr,
231        local_socket_id: SocketId,
232        timestamp: TimeStamp,
233        state: ConclusionWaitState,
234        shake: HandshakeControlInfo,
235        incoming: HsV5Info,
236    ) -> ConnectionResult {
237        // TODO: handle StreamId parsing error
238        let stream_id = incoming.sid.clone().and_then(|s| s.try_into().ok());
239        let remote_socket_id = shake.socket_id;
240        let key_size = incoming.key_size;
241
242        self.state = AccessControlRequested(state, timestamp, shake, incoming);
243
244        RequestAccess(AccessControlRequest {
245            local_socket_id,
246            remote,
247            remote_socket_id,
248            stream_id,
249            key_size,
250        })
251    }
252
253    fn accept_connection(
254        &mut self,
255        now: Instant,
256        state: &ConclusionWaitState,
257        timestamp: TimeStamp,
258        shake: HandshakeControlInfo,
259        info: HsV5Info,
260        key_settings: Option<KeySettings>,
261    ) -> ConnectionResult {
262        let response = gen_access_control_response(
263            now,
264            &mut self.init_settings,
265            state.from,
266            state.induction_time,
267            shake.clone(),
268            info,
269            key_settings,
270        );
271        let (hsv5, settings) = match response {
272            GenHsv5Result::Accept(h, c) => (h, c),
273            GenHsv5Result::NotHandled(e) => return NotHandled(e),
274            GenHsv5Result::Reject(r) => {
275                return self.make_rejection(&shake, state.from, timestamp, r);
276            }
277        };
278
279        let resp_handshake = ControlPacket {
280            timestamp,
281            dest_sockid: shake.socket_id,
282            control_type: ControlTypes::Handshake(HandshakeControlInfo {
283                syn_cookie: state.cookie,
284                socket_id: self.init_settings.local_sockid,
285                info: hsv5,
286                shake_type: ShakeType::Conclusion,
287                ..shake // TODO: this will pass peer wrong
288            }),
289        };
290
291        // finish the connection
292        Connected(
293            Some((resp_handshake.clone().into(), state.from)),
294            Connection {
295                settings,
296                handshake: Handshake::Listener(resp_handshake.control_type),
297            },
298        )
299    }
300
301    fn make_rejection(
302        &mut self,
303        response_to: &HandshakeControlInfo,
304        from: SocketAddr,
305        timestamp: TimeStamp,
306        r: ConnectionReject,
307    ) -> ConnectionResult {
308        self.state = InductionWait;
309        Reject(
310            Some((
311                ControlPacket {
312                    timestamp,
313                    dest_sockid: response_to.socket_id,
314                    control_type: ControlTypes::Handshake(HandshakeControlInfo {
315                        shake_type: ShakeType::Rejection(r.reason()),
316                        socket_id: self.init_settings.local_sockid,
317                        ..response_to.clone()
318                    }),
319                }
320                .into(),
321                from,
322            )),
323            r,
324        )
325    }
326}
327
#[cfg(test)]
mod test {
    use std::{
        net::{IpAddr, Ipv4Addr},
        time::Duration,
    };

    use assert_matches::assert_matches;
    use bytes::Bytes;
    use rand::random;

    use crate::options::*;

    use super::*;

    /// Address every simulated caller sends from.
    fn conn_addr() -> SocketAddr {
        SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8765)
    }

    /// A listener with default settings and access control disabled.
    fn test_listen() -> Listen {
        Listen::new(ConnInitSettings::default(), false)
    }

    /// An HSv5 induction handshake, as a caller sends first.
    fn test_induction() -> HandshakeControlInfo {
        HandshakeControlInfo {
            init_seq_num: random(),
            max_packet_size: PacketSize(1316),
            max_flow_size: PacketCount(256_000),
            shake_type: ShakeType::Induction,
            socket_id: random(),
            syn_cookie: 0,
            peer_addr: IpAddr::from([127, 0, 0, 1]),
            info: HandshakeVsInfo::V5(HsV5Info::default()),
        }
    }

    /// An HSv5 conclusion handshake echoing the cookie generated for `conn_addr()` and
    /// carrying an SRT handshake-request extension.
    fn test_conclusion() -> HandshakeControlInfo {
        HandshakeControlInfo {
            init_seq_num: random(),
            max_packet_size: PacketSize(1316),
            max_flow_size: PacketCount(256_000),
            shake_type: ShakeType::Conclusion,
            socket_id: random(),
            syn_cookie: gen_cookie(&conn_addr()),
            peer_addr: IpAddr::from([127, 0, 0, 1]),
            info: HandshakeVsInfo::V5(HsV5Info {
                key_size: KeySize::Unspecified,
                ext_hs: Some(SrtControlPacket::HandshakeRequest(SrtHandshake {
                    version: SrtVersion::CURRENT,
                    flags: SrtShakeFlags::SUPPORTED,
                    send_latency: Duration::from_secs(1),
                    recv_latency: Duration::from_secs(2),
                })),
                ext_km: None,
                ext_group: None,
                sid: None,
            }),
        }
    }

    /// Wraps a handshake in a control packet with a random destination socket id.
    fn build_hs_pack(i: HandshakeControlInfo) -> Packet {
        Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(0),
            dest_sockid: random(),
            control_type: ControlTypes::Handshake(i),
        })
    }

    /// Sends `test_induction()` from `conn_addr()` and asserts the listener answered it.
    fn send_induction(l: &mut Listen) {
        let resp = l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_induction()), conn_addr())),
        );
        assert_matches!(resp, SendPacket(_));
    }

    #[test]
    fn correct() {
        let mut l = test_listen();
        send_induction(&mut l);

        let resp = l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_conclusion()), conn_addr())),
        );
        // make sure it returns hs_ext
        assert_matches!(
            resp,
            Connected(
                Some(_),
                Connection {
                    handshake: Handshake::Listener(ControlTypes::Handshake(HandshakeControlInfo {
                        info: HandshakeVsInfo::V5(HsV5Info {
                            ext_hs: Some(_),
                            ..
                        }),
                        ..
                    })),
                    ..
                },
            )
        );
    }

    #[test]
    fn send_data_packet() {
        let mut l = test_listen();

        let dp = DataPacket {
            seq_number: random(),
            message_loc: PacketLocation::ONLY,
            in_order_delivery: false,
            encryption: DataEncryption::None,
            retransmitted: false,
            message_number: random(),
            timestamp: TimeStamp::from_micros(0),
            dest_sockid: random(),
            payload: Bytes::from(&b"asdf"[..]),
        };
        assert_matches!(
            l.handle_packet(Instant::now(), Ok((Packet::Data(dp.clone()), conn_addr()))),
            NotHandled(ConnectError::ControlExpected(d)) if d == dp
        );
    }

    #[test]
    fn send_ack2() {
        let mut l = test_listen();

        // `random::<u32>() + 1` overflows (panicking in debug builds) when `random` yields
        // `u32::MAX`; saturating keeps the value non-zero without that flake.
        let a2 = ControlTypes::Ack2(FullAckSeqNumber::new(random::<u32>().saturating_add(1)).unwrap());
        assert_matches!(
            l.handle_packet(
                Instant::now(),
                Ok((
                    Packet::Control(ControlPacket {
                        timestamp: TimeStamp::from_micros(0),
                        dest_sockid: random(),
                        control_type: a2.clone()
                    }),
                    conn_addr()
                )),
            ),
            NotHandled(ConnectError::HandshakeExpected(pack)) if pack == a2
        );
    }

    #[test]
    fn send_wrong_handshake() {
        let mut l = test_listen();

        // listen expects an induction first, send a conclusion first
        let shake = test_conclusion();
        assert_matches!(
            l.handle_packet(Instant::now(), Ok((build_hs_pack(shake.clone()), conn_addr()))),
            NotHandled(ConnectError::InductionExpected(s)) if s == shake
        );
    }

    #[test]
    fn send_induction_twice() {
        let mut l = test_listen();
        send_induction(&mut l);

        // send a rendezvous handshake after an induction
        let mut shake = test_induction();
        shake.shake_type = ShakeType::Waveahand;
        assert_matches!(
            l.handle_packet(Instant::now(), Ok((build_hs_pack(shake.clone()), conn_addr()))),
            NotHandled(ConnectError::ConclusionExpected(nc)) if nc == shake
        )
    }

    #[test]
    fn send_v4_conclusion() {
        let mut l = test_listen();
        send_induction(&mut l);

        let mut c = test_conclusion();
        c.info = HandshakeVsInfo::V4(SocketType::Datagram);

        let resp = l.handle_packet(Instant::now(), Ok((build_hs_pack(c), conn_addr())));

        assert_matches!(
            resp,
            NotHandled(ConnectError::UnsupportedProtocolVersion(4))
        );
    }

    #[test]
    fn send_no_ext_hs_conclusion() {
        let mut l = test_listen();
        send_induction(&mut l);

        let mut c = test_conclusion();
        c.info = HandshakeVsInfo::V5(HsV5Info::default());

        let resp = l.handle_packet(Instant::now(), Ok((build_hs_pack(c), conn_addr())));

        assert_matches!(resp, NotHandled(ConnectError::ExpectedExtFlags));
    }

    #[test]
    fn reject() {
        let mut l = Listen::new(ConnInitSettings::default(), true);
        send_induction(&mut l);

        let resp = l.handle_packet(
            Instant::now(),
            Ok((build_hs_pack(test_conclusion()), conn_addr())),
        );
        assert_matches!(resp, RequestAccess(_));

        let resp = l.handle_access_control_response(
            Instant::now(),
            AccessControlResponse::Rejected(RejectReason::Server(ServerRejectReason::Overload)),
        );
        assert_matches!(
            resp,
            Reject(
                _,
                ConnectionReject::Rejecting(RejectReason::Server(ServerRejectReason::Overload)),
            )
        );
    }

    #[test]
    fn advertise_key_size() {
        let mut l = Listen::new(ConnInitSettings::default(), true);
        send_induction(&mut l);

        let hs_key_size = KeySize::AES256;

        let shake = HandshakeControlInfo {
            info: HandshakeVsInfo::V5(HsV5Info {
                key_size: hs_key_size,
                ext_hs: Some(SrtControlPacket::HandshakeRequest(SrtHandshake {
                    version: SrtVersion::CURRENT,
                    flags: SrtShakeFlags::SUPPORTED,
                    send_latency: Duration::from_secs(1),
                    recv_latency: Duration::from_secs(2),
                })),
                ext_km: None,
                ext_group: None,
                sid: None,
            }),
            ..test_conclusion()
        };

        let hs_packet = Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(0),
            dest_sockid: random(),
            control_type: ControlTypes::Handshake(shake),
        });

        let RequestAccess(request_access) =
            l.handle_packet(Instant::now(), Ok((hs_packet, conn_addr())))
        else {
            panic!("expected a ConnectionResult::RequestAccess");
        };

        assert_eq!(request_access.key_size, hs_key_size);
    }
}