// rustybit_leechy_dht/requester.rs

use std::collections::{HashSet, VecDeque};
use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4, ToSocketAddrs};
use std::num::NonZeroU32;
use std::ops::ControlFlow;
use std::time::Duration;

use anyhow::Context;
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use tokio::net::UdpSocket;
use tokio::sync::{mpsc, oneshot};
use tokio::time::Instant;

use crate::requests::{GetPeersQueryMessage, KrpcMessage, KrpcMessageType};
use crate::util::generate_node_id;

/// Size in bytes of the buffer incoming UDP datagrams are read into (64 KiB).
const DEFAULT_BUF_SIZE: usize = 64 * 1024;
/// Number of seconds a `get_peers` query may stay unanswered before it counts as timed out.
const INFLIGHT_REQUEST_TIMEOUT_SECS: f64 = 1.0;

18#[derive(Debug)]
19pub struct DhtRequester {
20    read_buf: Vec<u8>,
21    serialized_get_peers_message: Vec<u8>,
22    node_queue: VecDeque<SocketAddrV4>,
23    processed_nodes: Vec<SocketAddrV4>,
24    seen_peers: Vec<Ipv4Addr>,
25    inflight_requests: VecDeque<(Instant, SocketAddrV4)>,
26    rate_limiter: DefaultDirectRateLimiter,
27}
28
29impl DhtRequester {
30    pub fn new(bootstrap_node_addrs: Option<Vec<SocketAddrV4>>, info_hash: [u8; 20]) -> anyhow::Result<Self> {
31        let node_queue = VecDeque::from(bootstrap_node_addrs.unwrap_or_else(|| {
32            ["dht.transmissionbt.com:6881", "dht.libtorrent.org:25401"]
33                .iter()
34                .filter_map(|&node| {
35                    node.to_socket_addrs().ok().and_then(|mut socket_addrs| {
36                        socket_addrs.find_map(|addr| {
37                            if let SocketAddr::V4(v4_addr) = addr {
38                                Some(v4_addr)
39                            } else {
40                                None
41                            }
42                        })
43                    })
44                })
45                .collect()
46        }));
47
48        if node_queue.is_empty() {
49            anyhow::bail!("No suitable bootstrap DHT nodes found");
50        }
51
52        let message = KrpcMessage {
53            transaction_id: "10".into(),
54            message_type: KrpcMessageType::Query {
55                name: "get_peers".into(),
56                query: GetPeersQueryMessage {
57                    id: generate_node_id(),
58                    info_hash,
59                },
60            },
61        };
62
63        let rate_limiter = RateLimiter::direct(Quota::per_second(NonZeroU32::new(200).unwrap()));
64
65        let mut serialized_message = Vec::new();
66        serde_bencode::to_writer(&message, &mut serialized_message)
67            .context("failed to serialize a get_peers DHT message")?;
68
69        Ok(DhtRequester {
70            read_buf: vec![0; DEFAULT_BUF_SIZE],
71            serialized_get_peers_message: serialized_message,
72            node_queue,
73            processed_nodes: Vec::new(),
74            inflight_requests: VecDeque::new(),
75            seen_peers: Vec::new(),
76            rate_limiter,
77        })
78    }
79
80    #[tracing::instrument(level = "debug", err, skip_all)]
81    pub async fn process_dht_nodes(
82        &mut self,
83        mut cancellation: oneshot::Receiver<()>,
84        peer_queue_sender: mpsc::Sender<SocketAddrV4>,
85    ) -> anyhow::Result<()> {
86        let socket = UdpSocket::bind("0.0.0.0:6881").await.context("binding UDP socket")?;
87
88        let mut request_cleanup_interval =
89            tokio::time::interval(Duration::from_secs_f64(INFLIGHT_REQUEST_TIMEOUT_SECS));
90
91        'main: loop {
92            tokio::select! {
93                result = socket.recv_from(&mut self.read_buf) => {
94                    let (read_bytes, from_node) = result.context("receiving a message from a node")?;
95                    let from_node = match from_node {
96                        SocketAddr::V4(addr) => addr,
97                        _ => {
98                            tracing::debug!("received a message from node using IPv6: {}", from_node);
99                            continue;
100                        }
101                    };
102
103                    self.inflight_requests.retain(|(_, node_addr)| node_addr != &from_node);
104
105                    tracing::trace!("Received a response from DHT node: {}", from_node);
106
107                    match serde_bencode::from_bytes::<KrpcMessage>(&self.read_buf[..read_bytes]) {
108                        Ok(get_peers_response) => {
109                            match get_peers_response.message_type {
110                                KrpcMessageType::Query { query, .. } => {
111                                    tracing::warn!(
112                                        addr = %from_node,
113                                        "Unexpected response from DHT node. Expected get_peers response, got Query: {:?}",
114                                        query
115                                    )
116                                },
117                                KrpcMessageType::Response { response }=> {
118                                    if !response.nodes.is_some() && !response.values.is_some() {
119                                        tracing::debug!(
120                                            addr = %from_node,
121                                            "Bad get_peers response from node: no nodes or peers",
122                                        )
123                                    }
124
125                                    if let Some(nodes) = response.nodes {
126                                        for compact_node_info in nodes.chunks(26) {
127                                            let compact_node_addr = &compact_node_info[20..];
128                                            let ip = TryInto::<[u8; 4]>::try_into(&compact_node_addr[..4]).context("converting IP from slice to array")?;
129                                            let port = ((compact_node_addr[4] as u16) << 8) | compact_node_addr[5] as u16;
130
131                                            let ip_addr = Ipv4Addr::from(ip);
132                                            let node_addr = SocketAddrV4::new(ip_addr, port);
133
134                                            if !self.processed_nodes.contains(&node_addr) && !self.node_queue.contains(&node_addr) {
135                                                self.node_queue.push_back(node_addr);
136                                            }
137                                        }
138                                    }
139
140                                    if let Some(peers) = response.values {
141                                        for compact_peer_addr in peers.iter() {
142                                            let ip = TryInto::<[u8; 4]>::try_into(&compact_peer_addr.0[..4]).context("converting IP from slice to array")?;
143                                            let port = ((compact_peer_addr.0[4] as u16) << 8) | compact_peer_addr.0[5] as u16;
144
145                                            let ip_addr = Ipv4Addr::from(ip);
146                                            if !self.seen_peers.contains(&ip_addr) {
147                                                self.seen_peers.push(ip_addr);
148                                                let peer_addr = SocketAddrV4::new(ip_addr, port);
149                                                match peer_queue_sender.send(peer_addr).await {
150                                                    Ok(_) => {},
151                                                    Err(_) => {
152                                                        tracing::debug!("Receiving half of the peer queue sender was dropped, shutting down...");
153                                                        break 'main;
154                                                    }
155
156                                                };
157                                            }
158                                        }
159                                    }
160                                },
161                                KrpcMessageType::Error { error } => {
162                                    tracing::debug!(
163                                        addr = %from_node,
164                                        "Got Error from DHT node: {:?}",
165                                        error
166                                    )
167                                }
168                            }
169                        }
170                        Err(e) => tracing::debug!(addr = %from_node, "An error happened while decoding a message from DHT node: {}", e)}
171                }
172                _ = self.rate_limiter.until_ready(), if !self.node_queue.is_empty() => {
173                    self
174                        .query_next_node(&socket)
175                        .await
176                        .context("querying next node in the queue")?;
177                }
178                _ = request_cleanup_interval.tick() => {}
179                _ = &mut cancellation => {
180                    tracing::debug!("Cancellation requsted. Requester is exiting");
181                    break;
182                }
183            }
184
185            // Drop oldest inflight requests to allow making new ones
186            if self
187                .inflight_requests
188                .front()
189                .is_some_and(|(req_time, _)| req_time.elapsed().as_secs_f64() > INFLIGHT_REQUEST_TIMEOUT_SECS)
190            {
191                self.inflight_requests
192                    .retain(|(req_time, _)| req_time.elapsed().as_secs_f64() > INFLIGHT_REQUEST_TIMEOUT_SECS);
193            }
194
195            if self.node_queue.is_empty() && self.inflight_requests.is_empty() {
196                // There is nothing left to do, as there are no more nodes to query
197                tracing::debug!("No DHT nodes left to query. Requester is exiting");
198                break;
199            }
200        }
201
202        Ok(())
203    }
204
205    #[tracing::instrument(level = "debug", err, skip_all)]
206    pub async fn query_next_node(&mut self, socket: &UdpSocket) -> anyhow::Result<bool> {
207        let Some(next_node) = self.node_queue.pop_front() else {
208            // We can't continue from this point, as there are no other nodes to request data from.
209            return Ok(false);
210        };
211
212        // Mark a node as processed, so that we won't try it again if some peer sends it back to us
213        // again
214        self.processed_nodes.push(next_node);
215
216        // Recored request time
217        let request_time = Instant::now();
218        self.inflight_requests.push_back((request_time, next_node));
219
220        socket
221            .send_to(&self.serialized_get_peers_message, next_node)
222            .await
223            .with_context(|| format!("sending a get_peers message to '{}'", next_node))?;
224
225        Ok(true)
226    }
227}