// srt_protocol/protocol/pending_connection/connect.rs

1use std::io::ErrorKind;
2use std::{
3    net::{IpAddr, SocketAddr},
4    time::Instant,
5};
6
7use log::info;
8use ConnectError::*;
9use ConnectState::*;
10use ConnectionResult::*;
11
12use crate::{
13    connection::Connection, packet::*, protocol::handshake::Handshake, settings::ConnInitSettings,
14};
15
16use super::{
17    hsv5::{start_hsv5_initiation, StartedInitiator},
18    ConnectError, ConnectionReject, ConnectionResult,
19};
20
21#[allow(clippy::large_enum_variant)]
22#[derive(Clone)]
23enum ConnectState {
24    Configured,
25    /// keep induction packet around for retransmit
26    InductionResponseWait(Packet),
27    /// keep conclusion packet around for retransmit
28    ConclusionResponseWait(Packet, StartedInitiator),
29}
30
31impl Default for ConnectState {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl ConnectState {
38    pub fn new() -> ConnectState {
39        Configured
40    }
41}
42
43pub struct Connect {
44    remote: SocketAddr,
45    local_addr: IpAddr,
46    init_settings: ConnInitSettings,
47    state: ConnectState,
48    streamid: Option<String>,
49    starting_send_seqnum: SeqNumber,
50}
51
52impl Connect {
53    pub fn new(
54        remote: SocketAddr,
55        local_addr: IpAddr,
56        init_settings: ConnInitSettings,
57        streamid: Option<String>,
58        starting_send_seqnum: SeqNumber,
59    ) -> Self {
60        Connect {
61            remote,
62            local_addr,
63            init_settings,
64            state: ConnectState::new(),
65            streamid,
66            starting_send_seqnum,
67        }
68    }
69
70    fn on_start(&mut self) -> ConnectionResult {
71        let packet = Packet::Control(ControlPacket {
72            dest_sockid: SocketId(0),
73            timestamp: TimeStamp::from_micros(0), // TODO: this is not zero in the reference implementation
74            control_type: ControlTypes::Handshake(HandshakeControlInfo {
75                init_seq_num: self.starting_send_seqnum,
76                max_packet_size: self.init_settings.max_packet_size,
77                max_flow_size: self.init_settings.max_flow_size,
78                socket_id: self.init_settings.local_sockid,
79                shake_type: ShakeType::Induction,
80                peer_addr: self.local_addr,
81                syn_cookie: 0,
82                info: HandshakeVsInfo::V4(SocketType::Datagram),
83            }),
84        });
85        self.state = InductionResponseWait(packet.clone());
86        SendPacket((packet, self.remote))
87    }
88
89    pub fn wait_for_induction(
90        &mut self,
91        from: SocketAddr,
92        timestamp: TimeStamp,
93        info: HandshakeControlInfo,
94        now: Instant,
95    ) -> ConnectionResult {
96        match (info.shake_type, &info.info, from) {
97            (ShakeType::Induction, HandshakeVsInfo::V5 { .. }, from) if from == self.remote => {
98                let (hsv5, cm) =
99                    start_hsv5_initiation(self.init_settings.clone(), self.streamid.clone(), now);
100
101                // send back a packet with the same syn cookie
102                let packet = Packet::Control(ControlPacket {
103                    timestamp,
104                    dest_sockid: SocketId(0),
105                    control_type: ControlTypes::Handshake(HandshakeControlInfo {
106                        shake_type: ShakeType::Conclusion,
107                        socket_id: self.init_settings.local_sockid,
108                        info: hsv5,
109                        init_seq_num: self.starting_send_seqnum,
110                        ..info
111                    }),
112                });
113                self.state = ConclusionResponseWait(packet.clone(), cm);
114                SendPacket((packet, from))
115            }
116            (ShakeType::Induction, HandshakeVsInfo::V5 { .. }, from) => {
117                NotHandled(UnexpectedHost(self.remote, from))
118            }
119            (ShakeType::Induction, version, _) => {
120                NotHandled(UnsupportedProtocolVersion(version.version()))
121            }
122            (_, _, _) => NotHandled(InductionExpected(info)),
123        }
124    }
125
126    fn wait_for_conclusion(
127        &mut self,
128        from: SocketAddr,
129        now: Instant,
130        info: HandshakeControlInfo,
131        initiator: StartedInitiator,
132    ) -> ConnectionResult {
133        match (info.shake_type, info.info.version(), from) {
134            (ShakeType::Conclusion, 5, from) if from == self.remote => {
135                let settings = match initiator.finish_hsv5_initiation(&info, from, now) {
136                    Ok(s) => s,
137                    Err(rr) => return NotHandled(rr),
138                };
139
140                // TODO: no handshake retransmit packet needed? is this right? Needs testing.
141                Connected(
142                    None,
143                    Connection {
144                        settings,
145                        handshake: Handshake::Connector,
146                    },
147                )
148            }
149            (ShakeType::Conclusion, 5, from) => NotHandled(UnexpectedHost(self.remote, from)),
150            (ShakeType::Conclusion, version, _) => NotHandled(UnsupportedProtocolVersion(version)),
151            (ShakeType::Rejection(rej), _, from) if from == self.remote => {
152                Reject(None, ConnectionReject::Rejected(rej))
153            }
154            (ShakeType::Rejection(_), _, from) => NotHandled(UnexpectedHost(self.remote, from)),
155            (ShakeType::Induction, _, _) => NoAction,
156            (_, _, _) => NotHandled(ConclusionExpected(info)),
157        }
158    }
159
160    pub fn handle_packet(&mut self, packet: ReceivePacketResult, now: Instant) -> ConnectionResult {
161        use ReceivePacketError::*;
162        match packet {
163            Ok((packet, from)) => match (self.state.clone(), packet) {
164                (InductionResponseWait(_), Packet::Control(control)) => {
165                    match control.control_type {
166                        ControlTypes::Handshake(shake) => {
167                            self.wait_for_induction(from, control.timestamp, shake, now)
168                        }
169                        control_type => NotHandled(HandshakeExpected(control_type)),
170                    }
171                }
172                (ConclusionResponseWait(_, cm), Packet::Control(control)) => {
173                    match control.control_type {
174                        ControlTypes::Handshake(shake) => {
175                            self.wait_for_conclusion(from, now, shake, cm)
176                        }
177                        control_type => NotHandled(HandshakeExpected(control_type)),
178                    }
179                }
180                (_, Packet::Data(data)) => NotHandled(ControlExpected(data)),
181                (_, _) => NoAction,
182            },
183            Err(Io(error)) if error.kind() == ErrorKind::ConnectionReset => {
184                info!("ConnectionReset received, listener may not have opened the port yet...");
185                NoAction
186            }
187            Err(Io(error)) => Failure(error),
188            Err(Parse(PacketParseError::BadConnectionType(c))) => Failure(std::io::Error::new(
189                ErrorKind::ConnectionReset,
190                Parse(PacketParseError::BadConnectionType(c)),
191            )),
192            Err(Parse(e)) => NotHandled(ConnectError::ParseFailed(e)),
193        }
194    }
195
196    pub fn handle_tick(&mut self, _now: Instant) -> ConnectionResult {
197        match &self.state {
198            Configured => self.on_start(),
199            InductionResponseWait(request_packet) => {
200                SendPacket((request_packet.clone(), self.remote))
201            }
202            ConclusionResponseWait(request_packet, _) => {
203                SendPacket((request_packet.clone(), self.remote))
204            }
205        }
206    }
207}
208
#[cfg(test)]
mod test {
    use std::time::Duration;

    use assert_matches::assert_matches;
    use rand::random;

    use crate::{
        options::{self, PacketCount, PacketSize},
        protocol::pending_connection::ConnectionReject,
    };

    use super::*;

    const TEST_SOCKID: SocketId = SocketId(7655);

    /// Address the fake listener sends from; matches the caller's configured remote.
    fn test_remote() -> SocketAddr {
        ([127, 0, 0, 1], 6666).into()
    }

    /// A caller aimed at `test_remote()` using `TEST_SOCKID` as its socket id.
    fn test_connect(sid: Option<String>) -> Connect {
        let settings = ConnInitSettings {
            local_sockid: TEST_SOCKID,
            key_settings: None,
            key_refresh: Default::default(),
            send_latency: Duration::from_millis(20),
            recv_latency: Duration::from_millis(20),
            bandwidth: Default::default(),
            statistics_interval: Duration::from_secs(1),
            recv_buffer_size: options::PacketCount(8192),
            send_buffer_size: options::PacketCount(8192),
            max_packet_size: options::PacketSize(1500),
            max_flow_size: options::PacketCount(8192),
            peer_idle_timeout: Duration::from_secs(5),
            too_late_packet_drop: true,
        };
        Connect::new(test_remote(), [127, 0, 0, 1].into(), settings, sid, random())
    }

    /// Handshake fields shared by every packet the fake listener sends.
    /// Callers override `syn_cookie` as needed.
    fn listener_handshake(shake_type: ShakeType) -> HandshakeControlInfo {
        HandshakeControlInfo {
            init_seq_num: random(),
            max_packet_size: PacketSize(8192),
            max_flow_size: PacketCount(1234),
            shake_type,
            socket_id: SocketId(5678),
            syn_cookie: 0,
            peer_addr: [127, 0, 0, 1].into(),
            info: HandshakeVsInfo::V5(HsV5Info::default()),
        }
    }

    /// Wraps a listener handshake into a control packet addressed to the caller.
    fn to_caller(handshake: HandshakeControlInfo) -> Packet {
        Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(0),
            dest_sockid: TEST_SOCKID,
            control_type: ControlTypes::Handshake(handshake),
        })
    }

    #[test]
    fn reject() {
        let mut connect = test_connect(Some("#!::u=test".into()));
        // The first tick emits the induction request.
        connect.handle_tick(Instant::now());

        let induction_response = to_caller(HandshakeControlInfo {
            syn_cookie: 5554,
            ..listener_handshake(ShakeType::Induction)
        });
        let conclusion =
            connect.handle_packet(Ok((induction_response, test_remote())), Instant::now());
        assert_matches!(
            conclusion,
            ConnectionResult::SendPacket((Packet::Control(ControlPacket {
                control_type: ControlTypes::Handshake(HandshakeControlInfo {
                    shake_type: ShakeType::Conclusion,
                    socket_id,
                    syn_cookie: 5554,
                    ..
                }), ..
            }), _)) if socket_id == TEST_SOCKID
        );

        // The listener now refuses the conclusion request.
        let rejection = to_caller(HandshakeControlInfo {
            syn_cookie: 2222,
            ..listener_handshake(ShakeType::Rejection(RejectReason::Server(
                ServerRejectReason::BadMode,
            )))
        });
        let outcome = connect.handle_packet(Ok((rejection, test_remote())), Instant::now());
        assert_matches!(
            outcome,
            ConnectionResult::Reject(
                _,
                ConnectionReject::Rejected(RejectReason::Server(ServerRejectReason::BadMode)),
            )
        );
    }
}