1use std::io::ErrorKind;
2use std::{
3 net::{IpAddr, SocketAddr},
4 time::Instant,
5};
6
7use log::info;
8use ConnectError::*;
9use ConnectState::*;
10use ConnectionResult::*;
11
12use crate::{
13 connection::Connection, packet::*, protocol::handshake::Handshake, settings::ConnInitSettings,
14};
15
16use super::{
17 hsv5::{start_hsv5_initiation, StartedInitiator},
18 ConnectError, ConnectionReject, ConnectionResult,
19};
20
21#[allow(clippy::large_enum_variant)]
22#[derive(Clone)]
23enum ConnectState {
24 Configured,
25 InductionResponseWait(Packet),
27 ConclusionResponseWait(Packet, StartedInitiator),
29}
30
31impl Default for ConnectState {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl ConnectState {
38 pub fn new() -> ConnectState {
39 Configured
40 }
41}
42
43pub struct Connect {
44 remote: SocketAddr,
45 local_addr: IpAddr,
46 init_settings: ConnInitSettings,
47 state: ConnectState,
48 streamid: Option<String>,
49 starting_send_seqnum: SeqNumber,
50}
51
52impl Connect {
53 pub fn new(
54 remote: SocketAddr,
55 local_addr: IpAddr,
56 init_settings: ConnInitSettings,
57 streamid: Option<String>,
58 starting_send_seqnum: SeqNumber,
59 ) -> Self {
60 Connect {
61 remote,
62 local_addr,
63 init_settings,
64 state: ConnectState::new(),
65 streamid,
66 starting_send_seqnum,
67 }
68 }
69
70 fn on_start(&mut self) -> ConnectionResult {
71 let packet = Packet::Control(ControlPacket {
72 dest_sockid: SocketId(0),
73 timestamp: TimeStamp::from_micros(0), control_type: ControlTypes::Handshake(HandshakeControlInfo {
75 init_seq_num: self.starting_send_seqnum,
76 max_packet_size: self.init_settings.max_packet_size,
77 max_flow_size: self.init_settings.max_flow_size,
78 socket_id: self.init_settings.local_sockid,
79 shake_type: ShakeType::Induction,
80 peer_addr: self.local_addr,
81 syn_cookie: 0,
82 info: HandshakeVsInfo::V4(SocketType::Datagram),
83 }),
84 });
85 self.state = InductionResponseWait(packet.clone());
86 SendPacket((packet, self.remote))
87 }
88
89 pub fn wait_for_induction(
90 &mut self,
91 from: SocketAddr,
92 timestamp: TimeStamp,
93 info: HandshakeControlInfo,
94 now: Instant,
95 ) -> ConnectionResult {
96 match (info.shake_type, &info.info, from) {
97 (ShakeType::Induction, HandshakeVsInfo::V5 { .. }, from) if from == self.remote => {
98 let (hsv5, cm) =
99 start_hsv5_initiation(self.init_settings.clone(), self.streamid.clone(), now);
100
101 let packet = Packet::Control(ControlPacket {
103 timestamp,
104 dest_sockid: SocketId(0),
105 control_type: ControlTypes::Handshake(HandshakeControlInfo {
106 shake_type: ShakeType::Conclusion,
107 socket_id: self.init_settings.local_sockid,
108 info: hsv5,
109 init_seq_num: self.starting_send_seqnum,
110 ..info
111 }),
112 });
113 self.state = ConclusionResponseWait(packet.clone(), cm);
114 SendPacket((packet, from))
115 }
116 (ShakeType::Induction, HandshakeVsInfo::V5 { .. }, from) => {
117 NotHandled(UnexpectedHost(self.remote, from))
118 }
119 (ShakeType::Induction, version, _) => {
120 NotHandled(UnsupportedProtocolVersion(version.version()))
121 }
122 (_, _, _) => NotHandled(InductionExpected(info)),
123 }
124 }
125
126 fn wait_for_conclusion(
127 &mut self,
128 from: SocketAddr,
129 now: Instant,
130 info: HandshakeControlInfo,
131 initiator: StartedInitiator,
132 ) -> ConnectionResult {
133 match (info.shake_type, info.info.version(), from) {
134 (ShakeType::Conclusion, 5, from) if from == self.remote => {
135 let settings = match initiator.finish_hsv5_initiation(&info, from, now) {
136 Ok(s) => s,
137 Err(rr) => return NotHandled(rr),
138 };
139
140 Connected(
142 None,
143 Connection {
144 settings,
145 handshake: Handshake::Connector,
146 },
147 )
148 }
149 (ShakeType::Conclusion, 5, from) => NotHandled(UnexpectedHost(self.remote, from)),
150 (ShakeType::Conclusion, version, _) => NotHandled(UnsupportedProtocolVersion(version)),
151 (ShakeType::Rejection(rej), _, from) if from == self.remote => {
152 Reject(None, ConnectionReject::Rejected(rej))
153 }
154 (ShakeType::Rejection(_), _, from) => NotHandled(UnexpectedHost(self.remote, from)),
155 (ShakeType::Induction, _, _) => NoAction,
156 (_, _, _) => NotHandled(ConclusionExpected(info)),
157 }
158 }
159
160 pub fn handle_packet(&mut self, packet: ReceivePacketResult, now: Instant) -> ConnectionResult {
161 use ReceivePacketError::*;
162 match packet {
163 Ok((packet, from)) => match (self.state.clone(), packet) {
164 (InductionResponseWait(_), Packet::Control(control)) => {
165 match control.control_type {
166 ControlTypes::Handshake(shake) => {
167 self.wait_for_induction(from, control.timestamp, shake, now)
168 }
169 control_type => NotHandled(HandshakeExpected(control_type)),
170 }
171 }
172 (ConclusionResponseWait(_, cm), Packet::Control(control)) => {
173 match control.control_type {
174 ControlTypes::Handshake(shake) => {
175 self.wait_for_conclusion(from, now, shake, cm)
176 }
177 control_type => NotHandled(HandshakeExpected(control_type)),
178 }
179 }
180 (_, Packet::Data(data)) => NotHandled(ControlExpected(data)),
181 (_, _) => NoAction,
182 },
183 Err(Io(error)) if error.kind() == ErrorKind::ConnectionReset => {
184 info!("ConnectionReset received, listener may not have opened the port yet...");
185 NoAction
186 }
187 Err(Io(error)) => Failure(error),
188 Err(Parse(PacketParseError::BadConnectionType(c))) => Failure(std::io::Error::new(
189 ErrorKind::ConnectionReset,
190 Parse(PacketParseError::BadConnectionType(c)),
191 )),
192 Err(Parse(e)) => NotHandled(ConnectError::ParseFailed(e)),
193 }
194 }
195
196 pub fn handle_tick(&mut self, _now: Instant) -> ConnectionResult {
197 match &self.state {
198 Configured => self.on_start(),
199 InductionResponseWait(request_packet) => {
200 SendPacket((request_packet.clone(), self.remote))
201 }
202 ConclusionResponseWait(request_packet, _) => {
203 SendPacket((request_packet.clone(), self.remote))
204 }
205 }
206 }
207}
208
#[cfg(test)]
mod test {
    use std::time::Duration;

    use assert_matches::assert_matches;
    use rand::random;

    use crate::{
        options::{self, PacketCount, PacketSize},
        protocol::pending_connection::ConnectionReject,
    };

    use super::*;

    const TEST_SOCKID: SocketId = SocketId(7655);

    /// Drives the first tick and checks that it emits an induction handshake
    /// addressed to the test remote.
    fn start_and_expect_induction(c: &mut Connect) {
        let resp = c.handle_tick(Instant::now());
        assert_matches!(
            resp,
            ConnectionResult::SendPacket((Packet::Control(ControlPacket {
                control_type: ControlTypes::Handshake(HandshakeControlInfo {
                    shake_type: ShakeType::Induction,
                    ..
                }), ..
            }), to)) if to == test_remote()
        );
    }

    /// An HSv5 induction response, as a listener would send it.
    fn test_induction_response() -> Packet {
        Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(0),
            dest_sockid: TEST_SOCKID,
            control_type: ControlTypes::Handshake(HandshakeControlInfo {
                syn_cookie: 5554,
                socket_id: SocketId(5678),
                info: HandshakeVsInfo::V5(HsV5Info::default()),
                init_seq_num: random(),
                max_packet_size: PacketSize(8192),
                max_flow_size: PacketCount(1234),
                shake_type: ShakeType::Induction,
                peer_addr: [127, 0, 0, 1].into(),
            }),
        })
    }

    #[test]
    fn reject() {
        let mut c = test_connect(Some("#!::u=test".into()));
        start_and_expect_induction(&mut c);

        let resp = c.handle_packet(Ok((test_induction_response(), test_remote())), Instant::now());
        assert_matches!(
            resp,
            ConnectionResult::SendPacket((Packet::Control(ControlPacket {
                control_type: ControlTypes::Handshake(HandshakeControlInfo {
                    shake_type: ShakeType::Conclusion,
                    socket_id,
                    syn_cookie: 5554,
                    ..
                }), ..
            }), _)) if socket_id == TEST_SOCKID
        );

        let rejection = Packet::Control(ControlPacket {
            timestamp: TimeStamp::from_micros(0),
            dest_sockid: TEST_SOCKID,
            control_type: ControlTypes::Handshake(HandshakeControlInfo {
                init_seq_num: random(),
                max_packet_size: PacketSize(8192),
                max_flow_size: PacketCount(1234),
                shake_type: ShakeType::Rejection(RejectReason::Server(ServerRejectReason::BadMode)),
                socket_id: SocketId(5678),
                syn_cookie: 2222,
                peer_addr: [127, 0, 0, 1].into(),
                info: HandshakeVsInfo::V5(HsV5Info::default()),
            }),
        });

        let resp = c.handle_packet(Ok((rejection, test_remote())), Instant::now());
        assert_matches!(
            resp,
            ConnectionResult::Reject(
                _,
                ConnectionReject::Rejected(RejectReason::Server(ServerRejectReason::BadMode)),
            )
        );
    }

    /// An induction response from an address other than the configured remote
    /// must not advance the handshake.
    #[test]
    fn induction_from_unexpected_host() {
        let mut c = test_connect(None);
        start_and_expect_induction(&mut c);

        let stranger: SocketAddr = ([127, 0, 0, 1], 7777).into();
        let resp = c.handle_packet(Ok((test_induction_response(), stranger)), Instant::now());
        assert_matches!(
            resp,
            ConnectionResult::NotHandled(ConnectError::UnexpectedHost(expected, actual))
                if expected == test_remote() && actual == stranger
        );

        // Still waiting for the induction response: the next tick retransmits
        // the induction request.
        start_and_expect_induction(&mut c);
    }

    fn test_remote() -> SocketAddr {
        ([127, 0, 0, 1], 6666).into()
    }

    fn test_connect(sid: Option<String>) -> Connect {
        Connect::new(
            test_remote(),
            [127, 0, 0, 1].into(),
            ConnInitSettings {
                local_sockid: TEST_SOCKID,
                key_settings: None,
                key_refresh: Default::default(),
                send_latency: Duration::from_millis(20),
                recv_latency: Duration::from_millis(20),
                bandwidth: Default::default(),
                statistics_interval: Duration::from_secs(1),
                recv_buffer_size: options::PacketCount(8192),
                send_buffer_size: options::PacketCount(8192),
                max_packet_size: options::PacketSize(1500),
                max_flow_size: options::PacketCount(8192),
                peer_idle_timeout: Duration::from_secs(5),
                too_late_packet_drop: true,
            },
            sid,
            random(),
        )
    }
}