torrust_tracker/shared/bit_torrent/tracker/udp/
client.rs

1use core::result::Result::{Err, Ok};
2use std::io::Cursor;
3use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr};
4use std::sync::Arc;
5use std::time::Duration;
6
7use aquatic_udp_protocol::{ConnectRequest, Request, Response, TransactionId};
8use tokio::net::UdpSocket;
9use tokio::time;
10use torrust_tracker_configuration::DEFAULT_TIMEOUT;
11use zerocopy::network_endian::I32;
12
13use super::Error;
14use crate::shared::bit_torrent::tracker::udp::MAX_PACKET_SIZE;
15
16pub const UDP_CLIENT_LOG_TARGET: &str = "UDP CLIENT";
17
18#[allow(clippy::module_name_repetitions)]
19#[derive(Debug)]
20pub struct UdpClient {
21    /// The socket to connect to
22    pub socket: Arc<UdpSocket>,
23
24    /// Timeout for sending and receiving packets
25    pub timeout: Duration,
26}
27
28impl UdpClient {
29    /// Creates a new `UdpClient` bound to the default port and ipv6 address
30    ///
31    /// # Errors
32    ///
33    /// Will return error if unable to bind to any port or ip address.
34    ///
35    async fn bound_to_default_ipv4(timeout: Duration) -> Result<Self, Error> {
36        let addr = SocketAddr::new(Ipv4Addr::UNSPECIFIED.into(), 0);
37
38        Self::bound(addr, timeout).await
39    }
40
41    /// Creates a new `UdpClient` bound to the default port and ipv6 address
42    ///
43    /// # Errors
44    ///
45    /// Will return error if unable to bind to any port or ip address.
46    ///
47    async fn bound_to_default_ipv6(timeout: Duration) -> Result<Self, Error> {
48        let addr = SocketAddr::new(Ipv6Addr::UNSPECIFIED.into(), 0);
49
50        Self::bound(addr, timeout).await
51    }
52
53    /// Creates a new `UdpClient` connected to a Udp server
54    ///
55    /// # Errors
56    ///
57    /// Will return any errors present in the call stack
58    ///
59    pub async fn connected(remote_addr: SocketAddr, timeout: Duration) -> Result<Self, Error> {
60        let client = if remote_addr.is_ipv4() {
61            Self::bound_to_default_ipv4(timeout).await?
62        } else {
63            Self::bound_to_default_ipv6(timeout).await?
64        };
65
66        client.connect(remote_addr).await?;
67        Ok(client)
68    }
69
70    /// Creates a `[UdpClient]` bound to a Socket.
71    ///
72    /// # Panics
73    ///
74    /// Panics if unable to get the `local_addr` of the bound socket.
75    ///
76    /// # Errors
77    ///
78    /// This function will return an error if the binding takes to long
79    /// or if there is an underlying OS error.
80    pub async fn bound(addr: SocketAddr, timeout: Duration) -> Result<Self, Error> {
81        tracing::trace!(target: UDP_CLIENT_LOG_TARGET, "binding to socket: {addr:?} ...");
82
83        let socket = time::timeout(timeout, UdpSocket::bind(addr))
84            .await
85            .map_err(|_| Error::TimeoutWhileBindingToSocket { addr })?
86            .map_err(|e| Error::UnableToBindToSocket { err: e.into(), addr })?;
87
88        let addr = socket.local_addr().expect("it should get the local address");
89
90        tracing::debug!(target: UDP_CLIENT_LOG_TARGET, "bound to socket: {addr:?}.");
91
92        let udp_client = Self {
93            socket: Arc::new(socket),
94            timeout,
95        };
96
97        Ok(udp_client)
98    }
99
100    /// # Errors
101    ///
102    /// Will return error if can't connect to the socket.
103    pub async fn connect(&self, remote_addr: SocketAddr) -> Result<(), Error> {
104        tracing::trace!(target: UDP_CLIENT_LOG_TARGET, "connecting to remote: {remote_addr:?} ...");
105
106        let () = time::timeout(self.timeout, self.socket.connect(remote_addr))
107            .await
108            .map_err(|_| Error::TimeoutWhileConnectingToRemote { remote_addr })?
109            .map_err(|e| Error::UnableToConnectToRemote {
110                err: e.into(),
111                remote_addr,
112            })?;
113
114        tracing::debug!(target: UDP_CLIENT_LOG_TARGET, "connected to remote: {remote_addr:?}.");
115
116        Ok(())
117    }
118
119    /// # Errors
120    ///
121    /// Will return error if:
122    ///
123    /// - Can't write to the socket.
124    /// - Can't send data.
125    pub async fn send(&self, bytes: &[u8]) -> Result<usize, Error> {
126        tracing::trace!(target: UDP_CLIENT_LOG_TARGET, "sending {bytes:?} ...");
127
128        let () = time::timeout(self.timeout, self.socket.writable())
129            .await
130            .map_err(|_| Error::TimeoutWaitForWriteableSocket)?
131            .map_err(|e| Error::UnableToGetWritableSocket { err: e.into() })?;
132
133        let sent_bytes = time::timeout(self.timeout, self.socket.send(bytes))
134            .await
135            .map_err(|_| Error::TimeoutWhileSendingData { data: bytes.to_vec() })?
136            .map_err(|e| Error::UnableToSendData {
137                err: e.into(),
138                data: bytes.to_vec(),
139            })?;
140
141        tracing::debug!(target: UDP_CLIENT_LOG_TARGET, "sent {sent_bytes} bytes to remote.");
142
143        Ok(sent_bytes)
144    }
145
146    /// # Errors
147    ///
148    /// Will return error if:
149    ///
150    /// - Can't read from the socket.
151    /// - Can't receive data.
152    ///
153    /// # Panics
154    ///
155    pub async fn receive(&self) -> Result<Vec<u8>, Error> {
156        tracing::trace!(target: UDP_CLIENT_LOG_TARGET, "receiving ...");
157
158        let mut buffer = [0u8; MAX_PACKET_SIZE];
159
160        let () = time::timeout(self.timeout, self.socket.readable())
161            .await
162            .map_err(|_| Error::TimeoutWaitForReadableSocket)?
163            .map_err(|e| Error::UnableToGetReadableSocket { err: e.into() })?;
164
165        let received_bytes = time::timeout(self.timeout, self.socket.recv(&mut buffer))
166            .await
167            .map_err(|_| Error::TimeoutWhileReceivingData)?
168            .map_err(|e| Error::UnableToReceivingData { err: e.into() })?;
169
170        let mut received: Vec<u8> = buffer.to_vec();
171        Vec::truncate(&mut received, received_bytes);
172
173        tracing::debug!(target: UDP_CLIENT_LOG_TARGET, "received {received_bytes} bytes: {received:?}");
174
175        Ok(received)
176    }
177}
178
179#[allow(clippy::module_name_repetitions)]
180#[derive(Debug)]
181pub struct UdpTrackerClient {
182    pub client: UdpClient,
183}
184
185impl UdpTrackerClient {
186    /// Creates a new `UdpTrackerClient` connected to a Udp Tracker server
187    ///
188    /// # Errors
189    ///
190    /// If unable to connect to the remote address.
191    ///
192    pub async fn new(remote_addr: SocketAddr, timeout: Duration) -> Result<UdpTrackerClient, Error> {
193        let client = UdpClient::connected(remote_addr, timeout).await?;
194        Ok(UdpTrackerClient { client })
195    }
196
197    /// # Errors
198    ///
199    /// Will return error if can't write request to bytes.
200    pub async fn send(&self, request: Request) -> Result<usize, Error> {
201        tracing::trace!(target: UDP_CLIENT_LOG_TARGET, "sending request {request:?} ...");
202
203        // Write request into a buffer
204        // todo: optimize the pre-allocated amount based upon request type.
205        let mut writer = Cursor::new(Vec::with_capacity(200));
206        let () = request
207            .write_bytes(&mut writer)
208            .map_err(|e| Error::UnableToWriteDataFromRequest { err: e.into(), request })?;
209
210        self.client.send(writer.get_ref()).await
211    }
212
213    /// # Errors
214    ///
215    /// Will return error if can't create response from the received payload (bytes buffer).
216    pub async fn receive(&self) -> Result<Response, Error> {
217        let response = self.client.receive().await?;
218
219        tracing::debug!(target: UDP_CLIENT_LOG_TARGET, "received {} bytes: {response:?}", response.len());
220
221        Response::parse_bytes(&response, true).map_err(|e| Error::UnableToParseResponse { err: e.into(), response })
222    }
223}
224
225/// Helper Function to Check if a UDP Service is Connectable
226///
227/// # Panics
228///
229/// It will return an error if unable to connect to the UDP service.
230///
231/// # Errors
232///
233pub async fn check(remote_addr: &SocketAddr) -> Result<String, String> {
234    tracing::debug!("Checking Service (detail): {remote_addr:?}.");
235
236    match UdpTrackerClient::new(*remote_addr, DEFAULT_TIMEOUT).await {
237        Ok(client) => {
238            let connect_request = ConnectRequest {
239                transaction_id: TransactionId(I32::new(123)),
240            };
241
242            // client.send() return usize, but doesn't use here
243            match client.send(connect_request.into()).await {
244                Ok(_) => (),
245                Err(e) => tracing::debug!("Error: {e:?}."),
246            };
247
248            let process = move |response| {
249                if matches!(response, Response::Connect(_connect_response)) {
250                    Ok("Connected".to_string())
251                } else {
252                    Err("Did not Connect".to_string())
253                }
254            };
255
256            let sleep = time::sleep(Duration::from_millis(2000));
257            tokio::pin!(sleep);
258
259            tokio::select! {
260                () = &mut sleep => {
261                      Err("Timed Out".to_string())
262                }
263                response = client.receive() => {
264                      process(response.unwrap())
265                }
266            }
267        }
268        Err(e) => Err(format!("{e:?}")),
269    }
270}