1use std::{
17 io,
18 net::SocketAddr,
19 pin,
20 sync::{Arc, Mutex},
21 task::ready,
22 time::Duration,
23};
24
25use ana_gotatun::{
26 noise::{Tunn, TunnResult, errors::WireGuardError, rate_limiter::RateLimiter},
27 packet::{Packet, WgKind},
28 x25519::{self},
29};
30use bytes::{Bytes, BytesMut};
31use tokio::{
32 select,
33 sync::mpsc::{self, error::TrySendError},
34 task::JoinHandle,
35 time::Interval,
36};
37use tracing::instrument;
38use zerocopy::IntoBytes as _;
39
/// Size in bytes of the buffer allocated for each receive on the underlay
/// UDP socket (65535).
const UDP_DATAGRAM_BUFFER_SIZE: usize = u16::MAX as usize;
/// Limit handed to the tunnel's `RateLimiter` when the socket is created.
// NOTE(review): presumably a handshakes-per-second budget — confirm against the rate limiter docs.
const HANDSHAKE_RATE_LIMIT: u64 = 20;
42
43#[derive(Debug, thiserror::Error)]
45pub enum SnapTunNgSocketError {
46 #[error("i/o error: {0}")]
48 IoError(#[from] std::io::Error),
49 #[error("receive queue closed")]
51 ReceiveQueueClosed,
52 #[error("initial handshake timed out")]
54 InitialHandshakeTimeout,
55 #[error("wireguard error: {0:?}")]
57 WireguardError(WireGuardError),
58}
59
60struct SnapTunNgClientDriver {
61 pub tunn: Arc<Mutex<Tunn>>,
62 pub underlay_socket: Arc<tokio::net::UdpSocket>,
63 pub dataplane_address: SocketAddr,
64 pub update_timers_interval: Interval,
65 pub packet_sender: mpsc::Sender<BytesMut>,
66 pub local_sockaddr: Option<SocketAddr>,
67}
68
69fn to_bytes(wg: WgKind) -> Packet<[u8]> {
70 match wg {
71 WgKind::HandshakeInit(p) => p.into_bytes(),
72 WgKind::HandshakeResp(p) => p.into_bytes(),
73 WgKind::CookieReply(p) => p.into_bytes(),
74 WgKind::Data(p) => p.into_bytes(),
75 }
76}
77
78impl SnapTunNgClientDriver {
79 fn new(
80 tunn: Arc<Mutex<Tunn>>,
81 underlay_socket: Arc<tokio::net::UdpSocket>,
82 dataplane_address: SocketAddr,
83 packet_sender: mpsc::Sender<BytesMut>,
84 ) -> Self {
85 let update_timers_interval = tokio::time::interval_at(
86 tokio::time::Instant::now() + Duration::from_millis(250),
87 Duration::from_millis(250),
88 );
89 Self {
90 tunn,
91 underlay_socket,
92 dataplane_address,
93 update_timers_interval,
94 packet_sender,
95 local_sockaddr: None,
96 }
97 }
98
99 #[instrument(name = "st-client", skip(self), fields(socket_addr= ?self.local_sockaddr))]
100 async fn initial_connection(&mut self) -> Result<SocketAddr, SnapTunNgSocketError> {
101 let handshake_init = self.tunn.lock().unwrap().format_handshake_initiation(false);
102 if let Some(wg_init) = handshake_init
103 && let Err(e) = self
104 .underlay_socket
105 .send_to(
106 to_bytes(WgKind::HandshakeInit(wg_init)).as_bytes(),
107 self.dataplane_address,
108 )
109 .await
110 {
111 return Err(SnapTunNgSocketError::IoError(e));
112 }
113 loop {
115 self.drive_once().await?;
116 if let Some(sockaddr) = self.tunn.lock().unwrap().get_initiator_remote_sockaddr() {
117 self.local_sockaddr = Some(sockaddr);
118 tracing::debug!(socket_addr=?sockaddr, "handshake completed, socket address assigned");
119 return Ok(sockaddr);
120 }
121 }
122 }
123
124 #[instrument(name = "st-client", skip(self), fields(socket_addr= ?self.local_sockaddr))]
125 async fn main_loop(mut self) {
126 loop {
127 let result = self.drive_once().await;
128 if let Err(ref e) = result {
129 tracing::error!(err=?e, "error driving tunnel");
130 }
131 if let Err(SnapTunNgSocketError::ReceiveQueueClosed) = result {
132 tracing::info!("receive queue closed, snap tunnel driver shutting down");
133 return;
134 }
135 }
136 }
137
138 async fn drive_once(&mut self) -> Result<(), SnapTunNgSocketError> {
139 let mut buf = BytesMut::zeroed(UDP_DATAGRAM_BUFFER_SIZE);
140 select! {
141 _ = self.update_timers_interval.tick() => {
142 let p = match self.tunn.lock().unwrap().update_timers() {
143 Ok(Some(wg)) => { Some(wg) },
144 Ok(None) => None,
145 Err(WireGuardError::ConnectionExpired) => {
146 return Err(SnapTunNgSocketError::InitialHandshakeTimeout);
147 }
148 Err(e) => {
149 tracing::error!(err=?e, "error updating timers on tunnel");
150 None
151 }
152 };
153 if let Some(wg) = p && let Err(e) = self.underlay_socket.send_to(to_bytes(wg).as_bytes(), self.dataplane_address).await {
154 return Err(SnapTunNgSocketError::IoError(e));
155 }
156 },
157 recv = self.underlay_socket.recv_from(&mut buf) => {
158 let (n, sender_addr) = match recv {
159 Ok((n, sender_addr)) => (n, sender_addr),
160 Err(e) => {
161 return Err(SnapTunNgSocketError::IoError(e));
162 }
163 };
164 if sender_addr != self.dataplane_address {
165 return Ok(());
167 }
168 buf.truncate(n);
169 let packet: Packet<[u8]> = Packet::from_bytes(buf);
170 let wg = packet.try_into_wg().expect("this needs to be handled");
171 let result = self.tunn.lock().unwrap().handle_incoming_packet(wg);
173 let ps = match result {
174 TunnResult::Done => None,
175 TunnResult::Err(e) => {
176 return Err(SnapTunNgSocketError::WireguardError(e));
177 }
178 TunnResult::WriteToNetwork(p) => {
179 let queued_packets = self.tunn.lock().unwrap().get_queued_packets().collect::<Vec<_>>();
181 let packets = std::iter::once(p).chain(queued_packets.into_iter());
182 Some(packets)
183
184 }
185 TunnResult::WriteToTunnel(mut p) => {
186 let buf = p.buf_mut().to_owned();
187
188 if !buf.is_empty() {
190 match self.packet_sender.try_send(buf) {
191 Ok(()) => {},
192 Err(TrySendError::Full(_)) => {
193 tracing::error!("receive channel is full, dropping packet");
194 }
195 Err(_) => {
196 return Err(SnapTunNgSocketError::ReceiveQueueClosed);
198 }
199 }
200 }
201 None
202 }
203 };
204 if let Some(packets) = ps {
205 for p in packets {
206 if let Err(e) = self.underlay_socket.send_to(to_bytes(p).as_bytes(), self.dataplane_address).await {
207 return Err(SnapTunNgSocketError::IoError(e));
208 }
209 }
210 }
211 }
212 }
213 Ok(())
214 }
215}
216
217pub struct SnapTunNgSocket {
219 tunn: Arc<Mutex<Tunn>>,
220 underlay_socket: Arc<tokio::net::UdpSocket>,
221 dataplane_address: SocketAddr,
222 local_sockaddr: SocketAddr,
223 receive_queue: tokio::sync::Mutex<mpsc::Receiver<BytesMut>>,
224 driver_task: JoinHandle<()>,
227}
228
229impl Drop for SnapTunNgSocket {
230 fn drop(&mut self) {
231 self.driver_task.abort();
232 }
233}
234
235impl SnapTunNgSocket {
236 pub async fn new(
247 static_private: x25519::StaticSecret,
248 peer_public: x25519::PublicKey,
249 underlay_socket: Arc<tokio::net::UdpSocket>,
250 dataplane_address: SocketAddr,
251 receive_queue_capacity: usize,
252 ) -> Result<Self, SnapTunNgSocketError> {
253 let local_public = x25519::PublicKey::from(&static_private);
254 let tunn = Arc::new(Mutex::new(Tunn::new(
255 static_private,
256 peer_public,
257 None,
258 None,
259 0,
260 Arc::new(RateLimiter::new(&local_public, HANDSHAKE_RATE_LIMIT)),
261 dataplane_address,
262 )));
263 let (packet_sender, packet_receiver) = mpsc::channel(receive_queue_capacity);
264 let mut driver = SnapTunNgClientDriver::new(
265 tunn.clone(),
266 underlay_socket.clone(),
267 dataplane_address,
268 packet_sender,
269 );
270 let socket_addr = driver.initial_connection().await?;
271 Ok(Self {
272 tunn,
273 underlay_socket,
274 dataplane_address,
275 local_sockaddr: socket_addr,
276 receive_queue: tokio::sync::Mutex::new(packet_receiver),
279 driver_task: tokio::spawn(driver.main_loop()),
280 })
281 }
282
283 #[instrument(name = "st-client", skip_all, fields(socket_addr= ?self.local_sockaddr, payload_len= payload.len()))]
285 pub async fn send(&self, payload: BytesMut) -> io::Result<()> {
286 let packet: Packet<[u8]> = Packet::from_bytes(payload);
287 let encapsulated_packet = self.tunn.lock().unwrap().handle_outgoing_packet(packet);
288 match encapsulated_packet {
289 Some(wg) => {
290 let bytes = match wg {
291 WgKind::HandshakeInit(p) => p.into_bytes(),
292 WgKind::HandshakeResp(p) => p.into_bytes(),
293 WgKind::CookieReply(p) => p.into_bytes(),
294 WgKind::Data(p) => p.into_bytes(),
295 };
296 tracing::trace!(dataplane_address=?self.dataplane_address, "sending packet");
297 self.underlay_socket
298 .send_to(bytes.as_bytes(), self.dataplane_address)
299 .await?;
300 Ok(())
301 }
302 None => {
303 tracing::trace!("handshake ongoing, queueing packet");
307 Ok(())
308 }
309 }
310 }
311
312 #[instrument(name = "st-client", skip_all, fields(socket_addr= ?self.local_sockaddr, payload_len= payload.len()))]
314 pub fn try_send(&self, payload: BytesMut) -> io::Result<()> {
315 let packet: Packet<[u8]> = Packet::from_bytes(payload);
316 match self.tunn.lock().unwrap().handle_outgoing_packet(packet) {
317 Some(wg) => {
318 let bytes = match wg {
319 WgKind::HandshakeInit(p) => p.into_bytes(),
320 WgKind::HandshakeResp(p) => p.into_bytes(),
321 WgKind::CookieReply(p) => p.into_bytes(),
322 WgKind::Data(p) => p.into_bytes(),
323 };
324 tracing::trace!(dataplane_address=?self.dataplane_address, "trying to send packet");
325 self.underlay_socket
326 .try_send_to(bytes.as_bytes(), self.dataplane_address)?;
327 Ok(())
328 }
329 None => {
330 Ok(())
334 }
335 }
336 }
337
338 pub async fn recv(&self) -> Result<Bytes, SnapTunNgSocketError> {
340 match self.receive_queue.lock().await.recv().await {
341 Some(packet) => Ok(packet.into()),
342 None => Err(SnapTunNgSocketError::ReceiveQueueClosed),
343 }
344 }
345
346 pub fn poll_recv(
348 &self,
349 cx: &mut std::task::Context<'_>,
350 ) -> std::task::Poll<Result<Bytes, SnapTunNgSocketError>> {
351 let mut receiver = ready!(pin::pin!(self.receive_queue.lock()).poll(cx));
352 match receiver.poll_recv(cx) {
353 std::task::Poll::Ready(Some(packet)) => std::task::Poll::Ready(Ok(packet.into())),
354 std::task::Poll::Ready(None) => {
355 tracing::trace!("receive queue closed, returning error");
356 std::task::Poll::Ready(Err(SnapTunNgSocketError::ReceiveQueueClosed))
357 }
358 std::task::Poll::Pending => std::task::Poll::Pending,
359 }
360 }
361
362 pub fn local_addr(&self) -> SocketAddr {
364 self.local_sockaddr
365 }
366
367 pub async fn writable(&self) -> io::Result<()> {
369 self.underlay_socket.writable().await
370 }
371}