rustydht_lib/dht/dht.rs

1use anyhow::anyhow;
2
3use rand::prelude::SliceRandom;
4use rand::{thread_rng, Rng};
5
6use futures::StreamExt;
7use tokio::net::{lookup_host, UdpSocket};
8use tokio::sync::mpsc;
9use tokio::time::sleep;
10
11use log::{debug, error, info, trace, warn};
12
13extern crate crc;
14use crc::{crc32, Hasher32};
15
16use std::convert::TryInto;
17use std::net::{IpAddr, SocketAddr};
18use std::sync::Arc;
19use std::sync::Mutex;
20use std::time::{Duration, Instant};
21
22use crate::common::ipv4_addr_src::IPV4AddrSource;
23use crate::common::{Id, Node};
24use crate::dht::dht_event::{DHTEvent, DHTEventType, MessageReceivedEvent};
25use crate::dht::socket::DHTSocket;
26use crate::dht::DHTSettings;
27use crate::errors::RustyDHTError;
28use crate::packets;
29use crate::packets::MessageBuilder;
30use crate::shutdown;
31use crate::storage::node_bucket_storage::NodeStorage;
32use crate::storage::node_wrapper::NodeWrapper;
33use crate::storage::peer_storage::{PeerInfo, PeerStorage};
34use crate::storage::throttler::Throttler;
35
36struct DHTState {
37    ip4_source: Box<dyn IPV4AddrSource + Send>,
38    our_id: Id,
39    buckets: Box<dyn NodeStorage + Send>,
40    peer_storage: PeerStorage,
41    token_secret: Vec<u8>,
42    old_token_secret: Vec<u8>,
43    settings: DHTSettings,
44    subscribers: Vec<mpsc::Sender<DHTEvent>>,
45}
46
47/// This struct is the heart of the library - contains data structure and business logic to run a DHT node.
48pub struct DHT {
49    socket: Arc<DHTSocket>,
50
51    /// Coarse-grained locking for stuff what needs it
52    state: Arc<Mutex<DHTState>>,
53
54    shutdown: shutdown::ShutdownReceiver,
55}
56
57impl DHT {
58    /// Returns the current Id used by the DHT.
59    pub fn get_id(&self) -> Id {
60        self.state.lock().unwrap().our_id
61    }
62
63    /// Returns a full dump of all the info hashes and peers in storage.
64    /// Peers that haven't announced since the provided `newer_than` can be optionally filtered.
65    pub fn get_info_hashes(&self, newer_than: Option<Instant>) -> Vec<(Id, Vec<PeerInfo>)> {
66        let state = self.state.lock().unwrap();
67        let hashes = state.peer_storage.get_info_hashes();
68        hashes
69            .iter()
70            .copied()
71            .map(|hash| (hash, state.peer_storage.get_peers_info(&hash, newer_than)))
72            .filter(|tup| !tup.1.is_empty())
73            .collect()
74    }
75
76    /// Returns information about all currently-verified DHT nodes that we're "connected" with.
77    pub fn get_nodes(&self) -> Vec<NodeWrapper> {
78        self.state.lock().unwrap().buckets.get_all_verified()
79    }
80
81    /// Return a copy of the settings used by the DHT
82    pub fn get_settings(&self) -> DHTSettings {
83        self.state.lock().unwrap().settings.clone()
84    }
85
86    /// Creates a new DHT.
87    ///
88    /// # Arguments
89    /// * `shutdown` - the DHT passes this to any sub-tasks that it spawns, and uses it to know when to stop its event own event loop.
90    /// * `id` - an optional initial Id for the DHT. The DHT may change its Id if at some point its not valid for the external IPv4 address (as reported by ip4_source).
91    /// * `listen_port` - the port that the DHT should bind its UDP socket on.
92    /// * `ip4_source` - Some type that implements IPV4AddrSource. This object will be used by the DHT to keep up to date on its IPv4 address.
93    /// * `buckets` - A function that takes an Id and returns a struct implementing NodeStorage. The NodeStorage-implementing type will be used to keep the nodes
94    /// (or routing table) of the DHT.
95    /// * `routers` - Array of string slices with hostname:port of DHT routers. These help us get bootstrapped onto the network.
96    /// * `settings` - DHTSettings struct containing settings that DHT will use.
97    pub fn new(
98        shutdown: shutdown::ShutdownReceiver,
99        id: Option<Id>,
100        socket_addr: std::net::SocketAddr,
101        ip4_source: Box<dyn IPV4AddrSource + Send>,
102        mut buckets: Box<dyn NodeStorage + Send>,
103        settings: DHTSettings,
104    ) -> Result<DHT, RustyDHTError> {
105        // If we were given a hardcoded id, use that until/unless we decide its invalid based on IP source.
106        // If we weren't given a hardcoded id, try to generate one based on IP source.
107        // Finally, if all else fails, generate a totally random id.
108        let our_id = {
109            match id {
110                Some(id) => id,
111
112                None => match ip4_source.get_best_ipv4() {
113                    Some(ip) => {
114                        let id = Id::from_ip(&IpAddr::V4(ip));
115                        info!(target: "rustydht_lib::DHT",
116                            "Our external IPv4 is {:?}. Generated id {} based on that",
117                            ip, id
118                        );
119                        id
120                    }
121
122                    None => {
123                        let id = Id::from_random(&mut thread_rng());
124                        info!(target: "rustydht_lib::DHT", "No external IPv4 provided. Using random id {} for now.", id);
125                        id
126                    }
127                },
128            }
129        };
130
131        buckets.set_id(our_id);
132
133        // Setup our UDP socket
134        let socket = {
135            let std_sock = std::net::UdpSocket::bind(socket_addr)
136                .map_err(|e| RustyDHTError::GeneralError(e.into()))?;
137            std_sock
138                .set_nonblocking(true)
139                .map_err(|e| RustyDHTError::GeneralError(e.into()))?;
140            let tokio_sock =
141                UdpSocket::from_std(std_sock).map_err(|e| RustyDHTError::GeneralError(e.into()))?;
142            Arc::new(DHTSocket::new(shutdown.clone(), tokio_sock))
143        };
144
145        let token_secret = make_token_secret(settings.token_secret_size);
146
147        let dht = DHT {
148            socket,
149            state: Arc::new(Mutex::new(DHTState {
150                ip4_source,
151                our_id,
152                buckets,
153                peer_storage: PeerStorage::new(
154                    settings.max_torrents,
155                    settings.max_peers_per_torrent,
156                ),
157                token_secret: token_secret.clone(),
158                old_token_secret: token_secret,
159                settings,
160                subscribers: vec![],
161            })),
162
163            shutdown,
164        };
165
166        Ok(dht)
167    }
168
169    /// Runs the main event loop of the DHT.
170    ///
171    /// It will only return if there's an error or if the DHT's ShutdownReceiver is signalled to stop the DHT.
172    pub async fn run_event_loop(&self) -> Result<(), RustyDHTError> {
173        match tokio::try_join!(
174            // One-time
175            self.ping_routers(self.shutdown.clone()),
176            // Loop indefinitely
177            self.accept_incoming_packets(),
178            self.periodic_router_ping(self.shutdown.clone()),
179            self.periodic_buddy_ping(self.shutdown.clone()),
180            self.periodic_find_node(self.shutdown.clone()),
181            self.periodic_ip4_maintenance(),
182            self.periodic_token_rotation(),
183            async {
184                let to_ret: Result<(), RustyDHTError> = Err(RustyDHTError::ShutdownError(anyhow!(
185                    "run_event_loop should shutdown"
186                )));
187                self.shutdown.clone().watch().await;
188                to_ret
189            }
190        ) {
191            Ok(_) => Ok(()),
192            Err(e) => {
193                if let RustyDHTError::ShutdownError(_) = e {
194                    Ok(())
195                } else {
196                    Err(e)
197                }
198            }
199        }
200    }
201
202    /// Sends a [Message](crate::packets::Message), awaits and returns a response.
203    ///
204    /// Note that `req` must be a request message (not a response or error message),
205    /// as this method awaits a reply. DHT automatically handles sending responses for
206    /// incoming requests.
207    ///
208    /// # Arguments
209    /// * `req` - the message that should be sent
210    /// * `dest` - the IP/port of the intended recipient
211    /// * `dest_id` - the Id of the DHT node listening at `dest`, if known. Otherwise, `None` can be provided.
212    /// * `timeout` - An optional timeout. If supplied, this function will return
213    /// a [RustyDHTError::TimeoutError](crate::errors::RustyDHTError::TimeoutError) if `dest` does not reply
214    /// to the message within the allotted time.
215    pub async fn send_request(
216        &self,
217        req: packets::Message,
218        dest: SocketAddr,
219        dest_id: Option<Id>,
220        timeout: Option<Duration>,
221    ) -> Result<packets::Message, RustyDHTError> {
222        match timeout {
223            Some(timeout) => match tokio::time::timeout(
224                timeout,
225                DHT::common_send_and_handle_response(
226                    self.state.clone(),
227                    self.socket.clone(),
228                    req.clone(),
229                    dest,
230                    dest_id,
231                ),
232            )
233            .await
234            {
235                Ok(result) => result,
236                Err(_) => Err(RustyDHTError::TimeoutError(anyhow!(
237                    "Timed out after {:?} waiting for {} to respond to {:?}",
238                    timeout,
239                    dest,
240                    req
241                ))),
242            },
243            None => {
244                DHT::common_send_and_handle_response(
245                    self.state.clone(),
246                    self.socket.clone(),
247                    req.clone(),
248                    dest,
249                    dest_id,
250                )
251                .await
252            }
253        }
254    }
255
256    /// Subscribe to DHTEvent notifications from the DHT.
257    ///
258    /// When you're sick of receiving events from the DHT, just drop the receiver.
259    pub fn subscribe(&self) -> mpsc::Receiver<DHTEvent> {
260        let (tx, rx) = mpsc::channel(32);
261        let mut state = self.state.lock().unwrap();
262        state.subscribers.push(tx);
263        rx
264    }
265}
266
267impl DHT {
268    async fn accept_incoming_packets(&self) -> Result<(), RustyDHTError> {
269        let mut throttler = {
270            let settings = &self.state.lock().unwrap().settings;
271            Throttler::<32>::new(
272                settings.throttle_packet_count,
273                Duration::from_secs(settings.throttle_period_secs),
274                Duration::from_secs(settings.throttle_naughty_timeout_secs),
275                Duration::from_secs(settings.throttle_max_tracking_secs),
276            )
277        };
278        let read_only = self.state.lock().unwrap().settings.read_only;
279        loop {
280            match async {
281                let (msg, addr) = self.socket.recv_from().await?;
282
283                // Drop the packet if the IP has been throttled.
284                if throttler.check_throttle(addr.ip(), None, None) {
285                    return Ok(());
286                }
287
288                // Filter out packets sent from port 0. We can't reply to these.
289                if addr.port() == 0 {
290                    warn!(target: "rustydht_lib::DHT", "{} has invalid port - dropping packet", addr);
291                    return Ok(());
292                }
293
294                // Respond to requests, but only if we're not read-only
295                if !read_only {
296                    self.accept_single_packet(msg.clone(), addr).await?;
297                }
298
299                // Send a MessageReceivedEvent to any subscribers
300                self.send_packet_to_subscribers(msg, addr).await;
301
302                Ok::<(), RustyDHTError>(())
303            }.await {
304                Ok(_) => continue,
305
306                Err(err) => match err {
307                    RustyDHTError::PacketParseError(internal) => {
308                        warn!(target: "rustydht_lib::DHT", "Packet parsing error: {:?}", internal);
309                        continue;
310                    }
311
312                    RustyDHTError::ConntrackError(e) => {
313                        warn!(target: "rustydht_lib::DHT", "Connection tracking error: {:?}", e);
314                        continue;
315                    }
316
317                    _ => {
318                        return Err(err);
319                    }
320                },
321            }
322        }
323    }
324
325    /// Carries out some common tasks for each incoming request
326    ///
327    /// 1. Determines if the requester's id is valid for their IP
328    /// 2. Makes sure they have a chance to join the routing table
329    fn common_request_handling(
330        &self,
331        remote_addr: SocketAddr,
332        msg: &packets::Message,
333    ) -> Result<(), RustyDHTError> {
334        let sender_id = match msg.get_author_id() {
335            Some(sender_id) => sender_id,
336            None => {
337                return Err(RustyDHTError::PacketParseError(anyhow!(
338                    "Failed to extract sender's id"
339                )));
340            }
341        };
342
343        // Is id valid for IP?
344        let is_id_valid = sender_id.is_valid_for_ip(&remote_addr.ip());
345        let read_only = match msg.read_only {
346            Some(ro) => ro,
347            _ => false,
348        };
349        if is_id_valid && !read_only {
350            self.state
351                .lock()
352                .unwrap()
353                .buckets
354                .add_or_update(Node::new(sender_id, remote_addr), false);
355        }
356        Ok(())
357    }
358
359    async fn accept_single_packet(
360        &self,
361        msg: packets::Message,
362        addr: SocketAddr,
363    ) -> Result<(), RustyDHTError> {
364        match &msg.message_type {
365            packets::MessageType::Request(request_variant) => {
366                match request_variant {
367                    packets::RequestSpecific::PingRequest(arguments) => {
368                        self.common_request_handling(addr, &msg)?;
369
370                        // Build a ping reply
371                        let reply = MessageBuilder::new_ping_response()
372                            .sender_id(self.state.lock().unwrap().our_id)
373                            .transaction_id(msg.transaction_id.clone())
374                            .requester_ip(addr)
375                            .build()?;
376                        self.socket
377                            .send_to(reply, addr, Some(arguments.requester_id))
378                            .await?;
379                    }
380
381                    packets::RequestSpecific::GetPeersRequest(arguments) => {
382                        self.common_request_handling(addr, &msg)?;
383                        let reply = {
384                            let state = self.state.lock().unwrap();
385
386                            // First, see if we have any peers for their info_hash
387                            let peers = {
388                                let newer_than = Instant::now().checked_sub(Duration::from_secs(
389                                    state.settings.get_peers_freshness_secs,
390                                ));
391                                let mut peers = state
392                                    .peer_storage
393                                    .get_peers(&arguments.info_hash, newer_than);
394                                peers.truncate(state.settings.max_peers_response);
395                                peers
396                            };
397                            let token = calculate_token(&addr, state.token_secret.clone());
398
399                            match peers.len() {
400                                0 => {
401                                    let nearest = state.buckets.get_nearest_nodes(
402                                        &arguments.info_hash,
403                                        Some(&arguments.requester_id),
404                                    );
405
406                                    MessageBuilder::new_get_peers_response()
407                                        .sender_id(state.our_id)
408                                        .transaction_id(msg.transaction_id)
409                                        .requester_ip(addr)
410                                        .token(token.to_vec())
411                                        .nodes(nearest)
412                                        .build()?
413                                }
414
415                                _ => MessageBuilder::new_get_peers_response()
416                                    .sender_id(state.our_id)
417                                    .transaction_id(msg.transaction_id)
418                                    .requester_ip(addr)
419                                    .token(token.to_vec())
420                                    .peers(peers)
421                                    .build()?,
422                            }
423                        };
424                        self.socket
425                            .send_to(reply, addr, Some(arguments.requester_id))
426                            .await?;
427                    }
428
429                    packets::RequestSpecific::FindNodeRequest(arguments) => {
430                        self.common_request_handling(addr, &msg)?;
431                        let reply = {
432                            let state = self.state.lock().unwrap();
433                            let nearest = state.buckets.get_nearest_nodes(
434                                &arguments.target,
435                                Some(&arguments.requester_id),
436                            );
437                            MessageBuilder::new_find_node_response()
438                                .sender_id(state.our_id)
439                                .transaction_id(msg.transaction_id)
440                                .requester_ip(addr)
441                                .nodes(nearest)
442                                .build()?
443                        };
444
445                        self.socket
446                            .send_to(reply, addr, Some(arguments.requester_id))
447                            .await?;
448                    }
449
450                    packets::RequestSpecific::AnnouncePeerRequest(arguments) => {
451                        self.common_request_handling(addr, &msg)?;
452                        let reply = {
453                            let mut state = self.state.lock().unwrap();
454
455                            let is_token_valid = arguments.token
456                                == calculate_token(&addr, state.token_secret.clone())
457                                || arguments.token
458                                    == calculate_token(&addr, state.old_token_secret.clone());
459
460                            if is_token_valid {
461                                let sockaddr = match arguments.implied_port {
462                                    Some(implied_port) if implied_port => addr,
463
464                                    _ => {
465                                        let mut tmp = addr;
466                                        tmp.set_port(arguments.port);
467                                        tmp
468                                    }
469                                };
470
471                                state
472                                    .peer_storage
473                                    .announce_peer(arguments.info_hash, sockaddr);
474
475                                Some(
476                                    MessageBuilder::new_announce_peer_response()
477                                        .sender_id(state.our_id)
478                                        .transaction_id(msg.transaction_id.clone())
479                                        .requester_ip(addr)
480                                        .build()?,
481                                )
482                            } else {
483                                None
484                            }
485                        };
486
487                        if let Some(reply) = reply {
488                            self.socket
489                                .send_to(reply, addr, Some(arguments.requester_id))
490                                .await?;
491                        }
492                    }
493
494                    packets::RequestSpecific::SampleInfoHashesRequest(arguments) => {
495                        self.common_request_handling(addr, &msg)?;
496                        let reply = {
497                            let state = self.state.lock().unwrap();
498
499                            let nearest = state.buckets.get_nearest_nodes(
500                                &arguments.target,
501                                Some(&arguments.requester_id),
502                            );
503
504                            let (info_hashes, total_info_hashes) = {
505                                let info_hashes = state.peer_storage.get_info_hashes();
506                                let total_info_hashes = info_hashes.len();
507                                let info_hashes = {
508                                    let mut rng = thread_rng();
509                                    state
510                                        .peer_storage
511                                        .get_info_hashes()
512                                        .as_mut_slice()
513                                        .partial_shuffle(
514                                            &mut rng,
515                                            state.settings.max_sample_response,
516                                        )
517                                        .0
518                                        .to_vec()
519                                };
520                                (info_hashes, total_info_hashes)
521                            };
522
523                            MessageBuilder::new_sample_infohashes_response()
524                                .sender_id(state.our_id)
525                                .transaction_id(msg.transaction_id)
526                                .requester_ip(addr)
527                                .interval(Duration::from_secs(
528                                    state.settings.min_sample_interval_secs.try_into().unwrap(),
529                                ))
530                                .nodes(nearest)
531                                .samples(info_hashes)
532                                .num_infohashes(total_info_hashes)
533                                .build()?
534                        };
535
536                        self.socket
537                            .send_to(reply, addr, Some(arguments.requester_id))
538                            .await?;
539                    }
540                }
541            }
542
543            packets::MessageType::Response(_) => { /*Responses should be handled by the sender via notification channel.*/
544            }
545
546            _ => {
547                warn!(target: "rustydht_lib::DHT",
548                    "Received unsupported/unexpected KRPCMessage variant from {:?}: {:?}",
549                    addr, msg
550                );
551            }
552        }
553
554        Ok(())
555    }
556
557    async fn send_packet_to_subscribers(&self, msg: packets::Message, _addr: SocketAddr) {
558        // Notify any subscribers about the event
559        let event = DHTEvent {
560            event_type: DHTEventType::MessageReceived(MessageReceivedEvent { message: msg }),
561        };
562        let mut state = self.state.lock().unwrap();
563        state.subscribers.retain(|sub| {
564            eprintln!("Gotta do notifications for {:?}", event);
565            match sub.try_send(event.clone()) {
566                Ok(()) => true,
567                Err(e) => match e {
568                    tokio::sync::mpsc::error::TrySendError::Closed(_) => {
569                        // Remove the sender from the subscriptions since they hung up on us (rude)
570                        trace!(target: "rustydht_lib::DHT", "Removing channel for closed DHTEvent subscriber");
571                        false
572                    }
573                    tokio::sync::mpsc::error::TrySendError::Full(_) => {
574                        warn!(target: "rustydht_lib::DHT", "DHTEvent subscriber channel is full - can't send event {:?}", event);
575                        true
576                    }
577                }
578            }
579        });
580    }
581
582    async fn periodic_buddy_ping(
583        &self,
584        shutdown: shutdown::ShutdownReceiver,
585    ) -> Result<(), RustyDHTError> {
586        loop {
587            let ping_check_interval_secs =
588                self.state.lock().unwrap().settings.ping_check_interval_secs;
589            sleep(Duration::from_secs(ping_check_interval_secs)).await;
590
591            // Package things that need state into a block so that Rust will not complain about MutexGuard kept past .await
592            let reverify_interval_secs = {
593                let mut state = self.state.lock().unwrap();
594                let count = state.buckets.count();
595                debug!(target: "rustydht_lib::DHT",
596                    "Pruning node buckets. Storage has {} unverified, {} verified",
597                    count.0,
598                    count.1,
599                );
600                let reverify_grace_period_secs = state.settings.reverify_grace_period_secs;
601                let verify_grace_period_secs = state.settings.verify_grace_period_secs;
602                state.buckets.prune(
603                    Duration::from_secs(reverify_grace_period_secs),
604                    Duration::from_secs(verify_grace_period_secs),
605                );
606
607                state.settings.reverify_interval_secs
608            };
609            match Instant::now().checked_sub(Duration::from_secs(reverify_interval_secs)) {
610                None => {
611                    debug!(target: "rustydht_lib::DHT", "Monotonic clock underflow - skipping this round of pings");
612                }
613
614                Some(ping_if_older_than) => {
615                    debug!(target: "rustydht_lib::DHT", "Sending pings to all nodes that have never verified or haven't been verified in a while");
616                    let (unverified, verified) = {
617                        let state = self.state.lock().unwrap();
618                        (
619                            state.buckets.get_all_unverified(),
620                            state.buckets.get_all_verified(),
621                        )
622                    };
623                    // Ping everybody we haven't verified
624                    for wrapper in unverified {
625                        // Some things in here are actually verified... don't bother them too often
626                        if let Some(last_verified) = wrapper.last_verified {
627                            if last_verified >= ping_if_older_than {
628                                continue;
629                            }
630                            trace!(target: "rustydht_lib::DHT", "Sending ping to reverify backup {:?}", wrapper.node);
631                        } else {
632                            trace!(target: "rustydht_lib::DHT",
633                                "Sending ping to verify {:?} (last seen {} seconds ago)",
634                                wrapper.node,
635                                (Instant::now() - wrapper.last_seen).as_secs()
636                            );
637                        }
638                        self.ping_internal(
639                            shutdown.clone(),
640                            wrapper.node.address,
641                            Some(wrapper.node.id),
642                        )
643                        .await?;
644                    }
645
646                    // Reverify those who haven't been verified recently
647                    for wrapper in verified {
648                        if let Some(last_verified) = wrapper.last_verified {
649                            if last_verified >= ping_if_older_than {
650                                continue;
651                            }
652                        }
653                        trace!(target: "rustydht_lib::DHT", "Sending ping to reverify {:?}", wrapper.node);
654                        self.ping_internal(
655                            shutdown.clone(),
656                            wrapper.node.address,
657                            Some(wrapper.node.id),
658                        )
659                        .await?;
660                    }
661                }
662            }
663        }
664    }
665
666    async fn periodic_find_node(
667        &self,
668        shutdown: shutdown::ShutdownReceiver,
669    ) -> Result<(), RustyDHTError> {
670        loop {
671            let find_node_interval_secs =
672                self.state.lock().unwrap().settings.find_nodes_interval_secs;
673            sleep(Duration::from_secs(find_node_interval_secs)).await;
674
675            let (count_unverified, count_verified) = self.state.lock().unwrap().buckets.count();
676
677            // If we don't know anybody, force a router ping.
678            // This is helpful if we've been asleep for a while and lost all peers
679            if count_verified == 0 {
680                self.ping_routers(shutdown.clone()).await?;
681            }
682
683            // Package things that need state into this block to avoid issues with MutexGuard kept over .await
684            let (nearest_nodes, id_near_us) = {
685                let state = self.state.lock().unwrap();
686                if count_unverified > state.settings.find_nodes_skip_count {
687                    debug!(target: "rustydht_lib::DHT", "Skipping find_node as we already have enough unverified");
688                    continue;
689                }
690
691                let id_near_us = state.our_id.make_mutant(4).unwrap();
692
693                // Find the closest nodes to ask
694                (
695                    state.buckets.get_nearest_nodes(&id_near_us, None),
696                    id_near_us,
697                )
698            };
699            trace!(
700                target: "rustydht_lib::DHT",
701                "Sending find_node to {} nodes about {:?}",
702                nearest_nodes.len(),
703                id_near_us
704            );
705            for node in nearest_nodes {
706                self.find_node_internal(shutdown.clone(), node.address, Some(node.id), id_near_us)
707                    .await?;
708            }
709        }
710    }
711
712    async fn periodic_ip4_maintenance(&self) -> Result<(), RustyDHTError> {
713        loop {
714            sleep(Duration::from_secs(10)).await;
715
716            let mut state = self.state.lock().unwrap();
717            state.ip4_source.decay();
718
719            if let Some(ip) = state.ip4_source.get_best_ipv4() {
720                let ip = IpAddr::V4(ip);
721                if !state.our_id.is_valid_for_ip(&ip) {
722                    let new_id = Id::from_ip(&ip);
723                    info!(target: "rustydht_lib::DHT",
724                        "Our current id {} is not valid for IP {}. Using new id {}",
725                        state.our_id,
726                        ip,
727                        new_id
728                    );
729                    state.our_id = new_id;
730                    state.buckets.set_id(new_id);
731                }
732            }
733        }
734    }
735
736    async fn periodic_router_ping(
737        &self,
738        shutdown: shutdown::ShutdownReceiver,
739    ) -> Result<(), RustyDHTError> {
740        loop {
741            let router_ping_interval_secs = self
742                .state
743                .lock()
744                .unwrap()
745                .settings
746                .router_ping_interval_secs;
747            sleep(Duration::from_secs(router_ping_interval_secs)).await;
748            debug!(target: "rustydht_lib::DHT", "Pinging routers");
749            let shutdown_clone = shutdown.clone();
750            self.ping_routers(shutdown_clone).await?;
751        }
752    }
753
754    async fn periodic_token_rotation(&self) -> Result<(), RustyDHTError> {
755        loop {
756            sleep(Duration::from_secs(300)).await;
757            self.rotate_token_secrets();
758        }
759    }
760
761    /// Build and send a ping to a target. Doesn't wait for a response
762    async fn ping_internal(
763        &self,
764        shutdown: shutdown::ShutdownReceiver,
765        target: SocketAddr,
766        target_id: Option<Id>,
767    ) -> Result<(), RustyDHTError> {
768        let state = self.state.clone();
769        let socket = self.socket.clone();
770        shutdown::ShutdownReceiver::spawn_with_shutdown(
771            shutdown,
772            async move {
773                let req = {
774                    let state = state.lock().unwrap();
775                    MessageBuilder::new_ping_request()
776                        .sender_id(state.our_id)
777                        .read_only(state.settings.read_only)
778                        .build()
779                        .expect("Failed to build ping packet")
780                };
781
782                if let Err(e) =
783                    DHT::common_send_and_handle_response(state, socket, req, target, target_id)
784                        .await
785                {
786                    match e {
787                        RustyDHTError::TimeoutError(e) => {
788                            debug!(target: "rustydht_lib::DHT", "Ping timed out: {}", e);
789                        }
790
791                        _ => {
792                            error!(target: "rustydht_lib::DHT", "Error during ping: {}", e);
793                        }
794                    }
795                }
796            },
797            format!("ping to {}", target),
798            Some(Duration::from_secs(5)),
799        );
800        Ok(())
801    }
802
803    /// Send a request and await on the notification channel for a response.
804    /// Then handle the response by adding the responder to routing tables,
805    /// letting them "vote" on our IPv4 address, etc.
806    ///
807    /// Note that DHTSocket guarantees that we'll only see responses to requests that we
808    /// actually sent - "spurious" or "extraneous" responses will be dropped in DHTSocket
809    /// before we see them.
810    async fn common_send_and_handle_response(
811        state: Arc<Mutex<DHTState>>,
812        socket: Arc<DHTSocket>,
813        msg: packets::Message,
814        target: SocketAddr,
815        target_id: Option<Id>,
816    ) -> Result<packets::Message, RustyDHTError> {
817        if !matches!(msg.message_type, packets::MessageType::Request(_)) {
818            return Err(RustyDHTError::GeneralError(anyhow!(
819                "This method is only for sending requests"
820            )));
821        }
822
823        let maybe_receiver = socket.send_to(msg.clone(), target, target_id).await?;
824        match maybe_receiver {
825            Some(mut receiver) => match receiver.recv().await {
826                Some(reply) => match &reply.message_type {
827                    packets::MessageType::Response(response_variant) => {
828                        // Get the id of the sender - safe to expect because all Response variants are guaranteed
829                        // to have an Id (only error doesn't)
830                        let their_id =
831                            reply.get_author_id().expect("response doesn't have Id!?");
832                        let id_is_valid = their_id.is_valid_for_ip(&target.ip());
833
834                        // Node is fit to be in our routing buckets and vote on our IPv4 only
835                        // if its id is valid for its IP.
836                        if id_is_valid {
837                            let mut state = state.lock().unwrap();
838                            DHT::ip4_vote_helper(&mut state, &target, &reply);
839                            state
840                                .buckets
841                                .add_or_update(Node::new(their_id, target), true);
842                        }
843
844                        // Special handling for find_node responses
845                        // Add the nodes we got back as "seen" (even though we haven't necessarily seen them directly yet).
846                        // They will be pinged later in an attempt to verify them.
847                        if let packets::ResponseSpecific::FindNodeResponse(args) = response_variant {
848                            let mut state = state.lock().unwrap();
849                            for node in &args.nodes {
850                                if node.id.is_valid_for_ip(&node.address.ip()) {
851                                    state.buckets.add_or_update(node.clone(), false);
852                                }
853                            }
854                        }
855
856                        Ok(reply)
857                    }
858
859                    _ => Err(RustyDHTError::GeneralError(anyhow!("Received wrong Message type as response from {}. {:?}", target, reply)))
860                },
861
862                None => Err(RustyDHTError::TimeoutError(anyhow!("Response channel was cleaned up while we were waiting for a reply from {}. Message we sent: {:?}", target, msg)))
863            },
864
865            None => Err(RustyDHTError::GeneralError(anyhow!("Didn't get a response channel after sending a request to {}. We sent: {:?}", target, msg)))
866        }
867    }
868
869    async fn ping_router<G: AsRef<str>>(
870        &self,
871        shutdown: shutdown::ShutdownReceiver,
872        hostname: G,
873    ) -> Result<(), RustyDHTError> {
874        let hostname = hostname.as_ref();
875        // Resolve and add to request storage
876        let resolve = lookup_host(hostname).await;
877        if let Err(err) = resolve {
878            // Used to only eat the specific errors corresponding to a failure to resolve,
879            // but they vary by platform and it's a pain. For now, we'll eat all host
880            // resolution errors.
881            warn!(
882                target: "rustydht_lib::DHT",
883                "Failed to resolve host {} due to error {:#?}. Try again later.",
884                hostname, err
885            );
886            return Ok(());
887        }
888
889        for socket_addr in resolve.unwrap() {
890            if socket_addr.is_ipv4() {
891                let shutdown_clone = shutdown.clone();
892                self.ping_internal(shutdown_clone, socket_addr, None)
893                    .await?;
894                break;
895            }
896        }
897        Ok(())
898    }
899
900    /// Pings some bittorrent routers
901    async fn ping_routers(
902        &self,
903        shutdown: shutdown::ShutdownReceiver,
904    ) -> Result<(), RustyDHTError> {
905        let mut futures = futures::stream::FuturesUnordered::new();
906        let routers = self.state.lock().unwrap().settings.routers.clone();
907        for hostname in routers {
908            let shutdown_clone = shutdown.clone();
909            futures.push(self.ping_router(shutdown_clone, hostname));
910        }
911        while let Some(result) = futures.next().await {
912            result?;
913        }
914        Ok(())
915    }
916
917    fn rotate_token_secrets(&self) {
918        let mut state = self.state.lock().unwrap();
919        let new_token_secret = make_token_secret(state.settings.token_secret_size);
920
921        state.old_token_secret = state.token_secret.clone();
922        state.token_secret = new_token_secret;
923        debug!(
924            target: "rustydht_lib::DHT",
925            "Rotating token secret. New secret is {:?}, old secret is {:?}",
926            state.token_secret,
927            state.old_token_secret
928        );
929    }
930
931    async fn find_node_internal(
932        &self,
933        shutdown: shutdown::ShutdownReceiver,
934        dest: SocketAddr,
935        dest_id: Option<Id>,
936        target: Id,
937    ) -> Result<(), RustyDHTError> {
938        let state = self.state.clone();
939        let socket = self.socket.clone();
940        shutdown::ShutdownReceiver::spawn_with_shutdown(
941            shutdown,
942            async move {
943                let req = {
944                    let state = state.lock().unwrap();
945                    MessageBuilder::new_find_node_request()
946                        .sender_id(state.our_id)
947                        .read_only(state.settings.read_only)
948                        .target(target)
949                        .build()
950                        .expect("Failed to build ping packet")
951                };
952
953                if let Err(e) =
954                    DHT::common_send_and_handle_response(state, socket, req, dest, dest_id).await
955                {
956                    match e {
957                        RustyDHTError::TimeoutError(e) => {
958                            debug!(target: "rustydht_lib::DHT", "find_node timed out: {}", e);
959                        }
960
961                        _ => {
962                            error!(target: "rustydht_lib::DHT", "Error during find_node: {}", e);
963                        }
964                    }
965                }
966            },
967            format!("find_node to {} for {}", dest, target),
968            Some(Duration::from_secs(5)),
969        );
970        Ok(())
971    }
972
973    /// Adds a 'vote' for whatever IP address the sender says we have.
974    fn ip4_vote_helper(state: &mut DHTState, addr: &SocketAddr, msg: &packets::Message) {
975        if let IpAddr::V4(their_ip) = addr.ip() {
976            if let Some(SocketAddr::V4(they_claim_our_sockaddr)) = &msg.requester_ip {
977                state
978                    .ip4_source
979                    .add_vote(their_ip, *they_claim_our_sockaddr.ip());
980            }
981        }
982    }
983}
984
985/// Calculates a peer announce token based on a sockaddr and some secret.
986/// Pretty positive this isn't cryptographically safe but I'm not too worried.
987/// If we care about that later we can use a proper HMAC or something.
988fn calculate_token<T: AsRef<[u8]>>(remote: &SocketAddr, secret: T) -> [u8; 4] {
989    let secret = secret.as_ref();
990    let mut digest = crc32::Digest::new(crc32::CASTAGNOLI);
991    // digest.write(&crate::packets::sockaddr_to_bytes(remote));
992    let octets = match remote.ip() {
993        std::net::IpAddr::V4(v4) => v4.octets().to_vec(),
994        std::net::IpAddr::V6(v6) => v6.octets().to_vec(),
995    };
996    digest.write(&octets);
997    digest.write(secret);
998    let checksum: u32 = digest.sum32();
999
1000    checksum.to_be_bytes()
1001}
1002
1003fn make_token_secret(size: usize) -> Vec<u8> {
1004    let mut token_secret = vec![0; size];
1005    token_secret.fill_with(|| thread_rng().gen());
1006    token_secret
1007}
1008
// Tests for the DHT's request handlers, router pinging and token-secret rotation.
//
// NOTE(review): most tests bind a fixed localhost UDP port (1948, 1974, 1995, 2014,
// 2037, 2171, 2244). Keep them distinct: the Rust test harness runs tests in
// parallel by default.
#[cfg(test)]
mod test {
    use super::*;
    use crate::common::ipv4_addr_src::StaticIPV4AddrSource;
    use crate::dht::DHTBuilder;
    use crate::dht::DHTSettingsBuilder;
    use anyhow::anyhow;
    use std::boxed::Box;
    use std::net::{Ipv4Addr, SocketAddrV4};

    /// Builds a test DHT listening on 127.0.0.1:`port`.
    ///
    /// The DHT uses the fixed id from `get_dht_id()` and a static IPv4 source of
    /// 1.2.3.4. Returns the DHT together with both halves of the shutdown channel
    /// it was built with.
    async fn make_test_dht(
        port: u16,
    ) -> (DHT, shutdown::ShutdownSender, shutdown::ShutdownReceiver) {
        let ipv4 = Ipv4Addr::new(1, 2, 3, 4);
        let phony_ip4 = Box::new(StaticIPV4AddrSource::new(ipv4));
        let (tx, rx) = shutdown::create_shutdown();
        (
            DHTBuilder::new()
                .initial_id(get_dht_id())
                .listen_addr(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))
                .ip_source(phony_ip4)
                .build(rx.clone())
                .unwrap(),
            tx,
            rx,
        )
    }

    /// A ping must be answered with a ping response that echoes the transaction id
    /// and carries the DHT's own id.
    #[tokio::test]
    async fn test_responds_to_ping() -> Result<(), RustyDHTError> {
        let requester_id = Id::from_random(&mut thread_rng());
        let ping_request = MessageBuilder::new_ping_request()
            .sender_id(requester_id)
            .build()?;

        let port = 1948;
        let (dht, mut shutdown_tx, shutdown_rx) = make_test_dht(port).await;
        // Run the DHT's event loop in the background until shutdown.
        shutdown::ShutdownReceiver::spawn_with_shutdown(
            shutdown_rx,
            async move {
                dht.run_event_loop().await.unwrap();
            },
            "Test DHT",
            Some(Duration::from_secs(10)),
        );

        let res = send_and_receive(ping_request.clone(), port).await.unwrap();

        assert_eq!(res.transaction_id, ping_request.transaction_id);
        assert_eq!(
            res.message_type,
            packets::MessageType::Response(packets::ResponseSpecific::PingResponse(
                packets::PingResponseArguments {
                    responder_id: get_dht_id()
                }
            ))
        );

        shutdown_tx.shutdown().await;

        Ok(())
    }

    /// A get_peers request must be answered with a get_peers response carrying the
    /// same transaction id.
    #[tokio::test]
    async fn test_responds_to_get_peers() -> Result<(), RustyDHTError> {
        let requester_id = Id::from_random(&mut thread_rng());
        let desired_info_hash = Id::from_random(&mut thread_rng());
        let request = MessageBuilder::new_get_peers_request()
            .sender_id(requester_id)
            .target(desired_info_hash)
            .build()?;

        let port = 1974;
        let (dht, mut shutdown_tx, shutdown_rx) = make_test_dht(port).await;
        shutdown::ShutdownReceiver::spawn_with_shutdown(
            shutdown_rx,
            async move {
                dht.run_event_loop().await.unwrap();
            },
            "Test DHT",
            Some(Duration::from_secs(10)),
        );

        let res = send_and_receive(request.clone(), port).await.unwrap();

        assert_eq!(res.transaction_id, request.transaction_id);
        assert!(matches!(
            res.message_type,
            packets::MessageType::Response(packets::ResponseSpecific::GetPeersResponse(
                packets::GetPeersResponseArguments { .. }
            ))
        ));

        shutdown_tx.shutdown().await;

        Ok(())
    }

    /// A find_node request must be answered with a find_node response carrying the
    /// same transaction id.
    #[tokio::test]
    async fn test_responds_to_find_node() -> Result<(), RustyDHTError> {
        let port = 1995;
        let (dht, mut shutdown_tx, shutdown_rx) = make_test_dht(port).await;
        shutdown::ShutdownReceiver::spawn_with_shutdown(
            shutdown_rx,
            async move {
                dht.run_event_loop().await.unwrap();
            },
            "Test DHT",
            Some(Duration::from_secs(10)),
        );

        let requester_id = Id::from_random(&mut thread_rng());
        let target = Id::from_random(&mut thread_rng());
        let request = MessageBuilder::new_find_node_request()
            .sender_id(requester_id)
            .target(target)
            .build()?;
        let res = send_and_receive(request.clone(), port).await.unwrap();

        assert_eq!(res.transaction_id, request.transaction_id);
        assert!(matches!(
            res.message_type,
            packets::MessageType::Response(packets::ResponseSpecific::FindNodeResponse(
                packets::FindNodeResponseArguments { .. }
            ))
        ));

        shutdown_tx.shutdown().await;

        Ok(())
    }

    /// End-to-end announce flow: get a token via get_peers, announce with it,
    /// then expect get_peers to return the announced peer (port 1234).
    #[tokio::test]
    async fn test_responds_to_announce_peer() -> Result<(), RustyDHTError> {
        let requester_id = Id::from_random(&mut thread_rng());
        let info_hash = Id::from_random(&mut thread_rng());
        let port = 2014;
        let (dht, mut shutdown_tx, shutdown_rx) = make_test_dht(port).await;
        shutdown::ShutdownReceiver::spawn_with_shutdown(
            shutdown_rx,
            async move {
                dht.run_event_loop().await.unwrap();
            },
            "Test DHT",
            Some(Duration::from_secs(10)),
        );

        // Send a get_peers request and get the response
        let reply = send_and_receive(
            MessageBuilder::new_get_peers_request()
                .sender_id(requester_id)
                .target(info_hash)
                .build()?,
            port,
        )
        .await
        .unwrap();

        // Extract the token from the get_peers response
        let token = {
            if let packets::MessageType::Response(packets::ResponseSpecific::GetPeersResponse(
                packets::GetPeersResponseArguments { token, .. },
            )) = reply.message_type
            {
                token
            } else {
                return Err(RustyDHTError::GeneralError(anyhow!("Didn't get token")));
            }
        };

        // Send an announce_peer request and get the response
        let reply = send_and_receive(
            MessageBuilder::new_announce_peer_request()
                .sender_id(requester_id)
                .target(info_hash)
                .port(1234)
                .token(token)
                .build()?,
            port,
        )
        .await
        .unwrap();

        // The response must be a ping response
        assert!(matches!(
            reply.message_type,
            packets::MessageType::Response(packets::ResponseSpecific::PingResponse(
                packets::PingResponseArguments { .. }
            ))
        ));

        // Send get peers again - this time we'll get a peer back (ourselves)
        let reply = send_and_receive(
            MessageBuilder::new_get_peers_request()
                .sender_id(requester_id)
                .target(info_hash)
                .build()?,
            port,
        )
        .await
        .unwrap();

        eprintln!("Received {:?}", reply);

        // Make sure we got a peer back
        let peers = {
            if let packets::MessageType::Response(packets::ResponseSpecific::GetPeersResponse(
                packets::GetPeersResponseArguments {
                    values: packets::GetPeersResponseValues::Peers(p),
                    ..
                },
            )) = reply.message_type
            {
                p
            } else {
                return Err(RustyDHTError::GeneralError(anyhow!("Didn't get peers")));
            }
        };
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].port(), 1234);
        eprintln!("all good!");
        shutdown_tx.shutdown().await;

        Ok(())
    }

    /// A sample_infohashes request to an empty DHT must report `num: 0`.
    #[tokio::test]
    async fn test_responds_to_sample_infohashes() -> Result<(), RustyDHTError> {
        let requester_id = Id::from_random(&mut thread_rng());
        let target = Id::from_random(&mut thread_rng());
        let request = MessageBuilder::new_sample_infohashes_request()
            .sender_id(requester_id)
            .target(target)
            .build()?;

        let port = 2037;
        let (dht, mut shutdown_tx, shutdown_rx) = make_test_dht(port).await;
        shutdown::ShutdownReceiver::spawn_with_shutdown(
            shutdown_rx,
            async move {
                dht.run_event_loop().await.unwrap();
            },
            "Test DHT",
            Some(Duration::from_secs(10)),
        );

        let res = send_and_receive(request.clone(), port).await.unwrap();

        assert_eq!(res.transaction_id, request.transaction_id);
        assert!(matches!(
            res.message_type,
            packets::MessageType::Response(packets::ResponseSpecific::SampleInfoHashesResponse(
                packets::SampleInfoHashesResponseArguments { num: 0, .. }
            ))
        ));

        shutdown_tx.shutdown().await;

        Ok(())
    }

    /// dht2 has dht1 (on `port1`) as its only router, with a 1-second ping interval.
    ///
    /// After dht2 publishes its first event to subscribers (presumably triggered by
    /// dht1's reply — TODO confirm), dht2's buckets must hold exactly one verified
    /// node and no unverified ones.
    #[tokio::test]
    async fn test_event_loop_pings_routers() {
        let (mut shutdown_tx, shutdown_rx) = shutdown::create_shutdown();
        let port1 = 2171;
        let dht1 = Arc::new(
            DHTBuilder::new()
                .initial_id(get_dht_id())
                .listen_addr(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port1))
                .ip_source(Box::new(StaticIPV4AddrSource::new(Ipv4Addr::new(
                    1, 2, 3, 4,
                ))))
                .settings(DHTSettingsBuilder::new().routers(vec![]).build())
                .build(shutdown_rx.clone())
                .unwrap(),
        );

        // Listens on an OS-assigned port (0) and uses dht1 as its only router.
        let dht2 = Arc::new(
            DHTBuilder::new()
                .initial_id(get_dht_id().make_mutant(4).unwrap())
                .listen_addr(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))
                .ip_source(Box::new(StaticIPV4AddrSource::new(Ipv4Addr::new(
                    1, 2, 3, 4,
                ))))
                .settings(
                    DHTSettingsBuilder::new()
                        .router_ping_interval_secs(1)
                        .routers(vec![format!("127.0.0.1:{}", port1)])
                        .build(),
                )
                .build(shutdown_rx.clone())
                .unwrap(),
        );

        // Subscribe before starting the event loops so no event is missed.
        let mut receiver = dht2.subscribe();

        shutdown::ShutdownReceiver::spawn_with_shutdown(
            shutdown_rx.clone(),
            async move {
                dht1.run_event_loop().await.unwrap();
            },
            "DHT1",
            None,
        );

        let dht2_clone = dht2.clone();
        shutdown::ShutdownReceiver::spawn_with_shutdown(
            shutdown_rx,
            async move { dht2_clone.run_event_loop().await.unwrap() },
            "DHT2",
            None,
        );

        receiver.recv().await;
        let (unverified, verified) = dht2.state.lock().unwrap().buckets.count();

        // Must drop dht2 as it contains a ShutdownReceiver channel which will block shutdown
        drop(dht2);

        shutdown_tx.shutdown().await;
        assert_eq!(unverified, 0);
        assert_eq!(verified, 1);
    }

    /// Secrets have the default size before and after rotation, and rotation
    /// actually replaces the current secret.
    #[tokio::test]
    async fn test_token_secret_rotation() {
        let ipv4 = Ipv4Addr::new(1, 2, 3, 4);
        let phony_ip4 = Box::new(StaticIPV4AddrSource::new(ipv4));
        let port = 2244;

        let dht = DHTBuilder::new()
            .initial_id(get_dht_id())
            .listen_addr(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))
            .ip_source(phony_ip4)
            .settings(DHTSettingsBuilder::new().routers(vec![]).build())
            .build(shutdown::create_shutdown().1)
            .unwrap();

        assert_eq!(
            dht.state.lock().unwrap().token_secret.len(),
            DHTSettings::default().token_secret_size
        );

        dht.rotate_token_secrets();
        assert_eq!(
            dht.state.lock().unwrap().old_token_secret.len(),
            DHTSettings::default().token_secret_size
        );
        assert_eq!(
            dht.state.lock().unwrap().token_secret.len(),
            DHTSettings::default().token_secret_size
        );

        // Secrets are random, so equality here would (with overwhelming probability)
        // mean rotation did not happen.
        let state = dht.state.lock().unwrap();
        assert_ne!(state.old_token_secret, state.token_secret);
    }

    // Dumb helper function because we can't declare a const or static Id
    fn get_dht_id() -> Id {
        Id::from_hex("0011223344556677889900112233445566778899").unwrap()
    }

    // Helper function that sends a single packet to the test DHT and then returns the response.
    // Uses a fresh ephemeral UDP socket on 127.0.0.1 and waits for one datagram
    // (up to 2048 bytes) with no timeout.
    async fn send_and_receive(
        msg: packets::Message,
        port: u16,
    ) -> Result<packets::Message, RustyDHTError> {
        let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        sock.send_to(
            &msg.clone().to_bytes().unwrap(),
            format!("127.0.0.1:{}", port),
        )
        .await
        .unwrap();
        let mut recv_buf = [0; 2048];
        let num_read = sock.recv_from(&mut recv_buf).await.unwrap().0;
        packets::Message::from_bytes(&recv_buf[..num_read])
    }
}