1use anyhow::anyhow;
2
3use rand::prelude::SliceRandom;
4use rand::{thread_rng, Rng};
5
6use futures::StreamExt;
7use tokio::net::{lookup_host, UdpSocket};
8use tokio::sync::mpsc;
9use tokio::time::sleep;
10
11use log::{debug, error, info, trace, warn};
12
13extern crate crc;
14use crc::{crc32, Hasher32};
15
16use std::convert::TryInto;
17use std::net::{IpAddr, SocketAddr};
18use std::sync::Arc;
19use std::sync::Mutex;
20use std::time::{Duration, Instant};
21
22use crate::common::ipv4_addr_src::IPV4AddrSource;
23use crate::common::{Id, Node};
24use crate::dht::dht_event::{DHTEvent, DHTEventType, MessageReceivedEvent};
25use crate::dht::socket::DHTSocket;
26use crate::dht::DHTSettings;
27use crate::errors::RustyDHTError;
28use crate::packets;
29use crate::packets::MessageBuilder;
30use crate::shutdown;
31use crate::storage::node_bucket_storage::NodeStorage;
32use crate::storage::node_wrapper::NodeWrapper;
33use crate::storage::peer_storage::{PeerInfo, PeerStorage};
34use crate::storage::throttler::Throttler;
35
36struct DHTState {
37 ip4_source: Box<dyn IPV4AddrSource + Send>,
38 our_id: Id,
39 buckets: Box<dyn NodeStorage + Send>,
40 peer_storage: PeerStorage,
41 token_secret: Vec<u8>,
42 old_token_secret: Vec<u8>,
43 settings: DHTSettings,
44 subscribers: Vec<mpsc::Sender<DHTEvent>>,
45}
46
47pub struct DHT {
49 socket: Arc<DHTSocket>,
50
51 state: Arc<Mutex<DHTState>>,
53
54 shutdown: shutdown::ShutdownReceiver,
55}
56
57impl DHT {
58 pub fn get_id(&self) -> Id {
60 self.state.lock().unwrap().our_id
61 }
62
63 pub fn get_info_hashes(&self, newer_than: Option<Instant>) -> Vec<(Id, Vec<PeerInfo>)> {
66 let state = self.state.lock().unwrap();
67 let hashes = state.peer_storage.get_info_hashes();
68 hashes
69 .iter()
70 .copied()
71 .map(|hash| (hash, state.peer_storage.get_peers_info(&hash, newer_than)))
72 .filter(|tup| !tup.1.is_empty())
73 .collect()
74 }
75
76 pub fn get_nodes(&self) -> Vec<NodeWrapper> {
78 self.state.lock().unwrap().buckets.get_all_verified()
79 }
80
81 pub fn get_settings(&self) -> DHTSettings {
83 self.state.lock().unwrap().settings.clone()
84 }
85
86 pub fn new(
98 shutdown: shutdown::ShutdownReceiver,
99 id: Option<Id>,
100 socket_addr: std::net::SocketAddr,
101 ip4_source: Box<dyn IPV4AddrSource + Send>,
102 mut buckets: Box<dyn NodeStorage + Send>,
103 settings: DHTSettings,
104 ) -> Result<DHT, RustyDHTError> {
105 let our_id = {
109 match id {
110 Some(id) => id,
111
112 None => match ip4_source.get_best_ipv4() {
113 Some(ip) => {
114 let id = Id::from_ip(&IpAddr::V4(ip));
115 info!(target: "rustydht_lib::DHT",
116 "Our external IPv4 is {:?}. Generated id {} based on that",
117 ip, id
118 );
119 id
120 }
121
122 None => {
123 let id = Id::from_random(&mut thread_rng());
124 info!(target: "rustydht_lib::DHT", "No external IPv4 provided. Using random id {} for now.", id);
125 id
126 }
127 },
128 }
129 };
130
131 buckets.set_id(our_id);
132
133 let socket = {
135 let std_sock = std::net::UdpSocket::bind(socket_addr)
136 .map_err(|e| RustyDHTError::GeneralError(e.into()))?;
137 std_sock
138 .set_nonblocking(true)
139 .map_err(|e| RustyDHTError::GeneralError(e.into()))?;
140 let tokio_sock =
141 UdpSocket::from_std(std_sock).map_err(|e| RustyDHTError::GeneralError(e.into()))?;
142 Arc::new(DHTSocket::new(shutdown.clone(), tokio_sock))
143 };
144
145 let token_secret = make_token_secret(settings.token_secret_size);
146
147 let dht = DHT {
148 socket,
149 state: Arc::new(Mutex::new(DHTState {
150 ip4_source,
151 our_id,
152 buckets,
153 peer_storage: PeerStorage::new(
154 settings.max_torrents,
155 settings.max_peers_per_torrent,
156 ),
157 token_secret: token_secret.clone(),
158 old_token_secret: token_secret,
159 settings,
160 subscribers: vec![],
161 })),
162
163 shutdown,
164 };
165
166 Ok(dht)
167 }
168
169 pub async fn run_event_loop(&self) -> Result<(), RustyDHTError> {
173 match tokio::try_join!(
174 self.ping_routers(self.shutdown.clone()),
176 self.accept_incoming_packets(),
178 self.periodic_router_ping(self.shutdown.clone()),
179 self.periodic_buddy_ping(self.shutdown.clone()),
180 self.periodic_find_node(self.shutdown.clone()),
181 self.periodic_ip4_maintenance(),
182 self.periodic_token_rotation(),
183 async {
184 let to_ret: Result<(), RustyDHTError> = Err(RustyDHTError::ShutdownError(anyhow!(
185 "run_event_loop should shutdown"
186 )));
187 self.shutdown.clone().watch().await;
188 to_ret
189 }
190 ) {
191 Ok(_) => Ok(()),
192 Err(e) => {
193 if let RustyDHTError::ShutdownError(_) = e {
194 Ok(())
195 } else {
196 Err(e)
197 }
198 }
199 }
200 }
201
202 pub async fn send_request(
216 &self,
217 req: packets::Message,
218 dest: SocketAddr,
219 dest_id: Option<Id>,
220 timeout: Option<Duration>,
221 ) -> Result<packets::Message, RustyDHTError> {
222 match timeout {
223 Some(timeout) => match tokio::time::timeout(
224 timeout,
225 DHT::common_send_and_handle_response(
226 self.state.clone(),
227 self.socket.clone(),
228 req.clone(),
229 dest,
230 dest_id,
231 ),
232 )
233 .await
234 {
235 Ok(result) => result,
236 Err(_) => Err(RustyDHTError::TimeoutError(anyhow!(
237 "Timed out after {:?} waiting for {} to respond to {:?}",
238 timeout,
239 dest,
240 req
241 ))),
242 },
243 None => {
244 DHT::common_send_and_handle_response(
245 self.state.clone(),
246 self.socket.clone(),
247 req.clone(),
248 dest,
249 dest_id,
250 )
251 .await
252 }
253 }
254 }
255
256 pub fn subscribe(&self) -> mpsc::Receiver<DHTEvent> {
260 let (tx, rx) = mpsc::channel(32);
261 let mut state = self.state.lock().unwrap();
262 state.subscribers.push(tx);
263 rx
264 }
265}
266
267impl DHT {
268 async fn accept_incoming_packets(&self) -> Result<(), RustyDHTError> {
269 let mut throttler = {
270 let settings = &self.state.lock().unwrap().settings;
271 Throttler::<32>::new(
272 settings.throttle_packet_count,
273 Duration::from_secs(settings.throttle_period_secs),
274 Duration::from_secs(settings.throttle_naughty_timeout_secs),
275 Duration::from_secs(settings.throttle_max_tracking_secs),
276 )
277 };
278 let read_only = self.state.lock().unwrap().settings.read_only;
279 loop {
280 match async {
281 let (msg, addr) = self.socket.recv_from().await?;
282
283 if throttler.check_throttle(addr.ip(), None, None) {
285 return Ok(());
286 }
287
288 if addr.port() == 0 {
290 warn!(target: "rustydht_lib::DHT", "{} has invalid port - dropping packet", addr);
291 return Ok(());
292 }
293
294 if !read_only {
296 self.accept_single_packet(msg.clone(), addr).await?;
297 }
298
299 self.send_packet_to_subscribers(msg, addr).await;
301
302 Ok::<(), RustyDHTError>(())
303 }.await {
304 Ok(_) => continue,
305
306 Err(err) => match err {
307 RustyDHTError::PacketParseError(internal) => {
308 warn!(target: "rustydht_lib::DHT", "Packet parsing error: {:?}", internal);
309 continue;
310 }
311
312 RustyDHTError::ConntrackError(e) => {
313 warn!(target: "rustydht_lib::DHT", "Connection tracking error: {:?}", e);
314 continue;
315 }
316
317 _ => {
318 return Err(err);
319 }
320 },
321 }
322 }
323 }
324
325 fn common_request_handling(
330 &self,
331 remote_addr: SocketAddr,
332 msg: &packets::Message,
333 ) -> Result<(), RustyDHTError> {
334 let sender_id = match msg.get_author_id() {
335 Some(sender_id) => sender_id,
336 None => {
337 return Err(RustyDHTError::PacketParseError(anyhow!(
338 "Failed to extract sender's id"
339 )));
340 }
341 };
342
343 let is_id_valid = sender_id.is_valid_for_ip(&remote_addr.ip());
345 let read_only = match msg.read_only {
346 Some(ro) => ro,
347 _ => false,
348 };
349 if is_id_valid && !read_only {
350 self.state
351 .lock()
352 .unwrap()
353 .buckets
354 .add_or_update(Node::new(sender_id, remote_addr), false);
355 }
356 Ok(())
357 }
358
359 async fn accept_single_packet(
360 &self,
361 msg: packets::Message,
362 addr: SocketAddr,
363 ) -> Result<(), RustyDHTError> {
364 match &msg.message_type {
365 packets::MessageType::Request(request_variant) => {
366 match request_variant {
367 packets::RequestSpecific::PingRequest(arguments) => {
368 self.common_request_handling(addr, &msg)?;
369
370 let reply = MessageBuilder::new_ping_response()
372 .sender_id(self.state.lock().unwrap().our_id)
373 .transaction_id(msg.transaction_id.clone())
374 .requester_ip(addr)
375 .build()?;
376 self.socket
377 .send_to(reply, addr, Some(arguments.requester_id))
378 .await?;
379 }
380
381 packets::RequestSpecific::GetPeersRequest(arguments) => {
382 self.common_request_handling(addr, &msg)?;
383 let reply = {
384 let state = self.state.lock().unwrap();
385
386 let peers = {
388 let newer_than = Instant::now().checked_sub(Duration::from_secs(
389 state.settings.get_peers_freshness_secs,
390 ));
391 let mut peers = state
392 .peer_storage
393 .get_peers(&arguments.info_hash, newer_than);
394 peers.truncate(state.settings.max_peers_response);
395 peers
396 };
397 let token = calculate_token(&addr, state.token_secret.clone());
398
399 match peers.len() {
400 0 => {
401 let nearest = state.buckets.get_nearest_nodes(
402 &arguments.info_hash,
403 Some(&arguments.requester_id),
404 );
405
406 MessageBuilder::new_get_peers_response()
407 .sender_id(state.our_id)
408 .transaction_id(msg.transaction_id)
409 .requester_ip(addr)
410 .token(token.to_vec())
411 .nodes(nearest)
412 .build()?
413 }
414
415 _ => MessageBuilder::new_get_peers_response()
416 .sender_id(state.our_id)
417 .transaction_id(msg.transaction_id)
418 .requester_ip(addr)
419 .token(token.to_vec())
420 .peers(peers)
421 .build()?,
422 }
423 };
424 self.socket
425 .send_to(reply, addr, Some(arguments.requester_id))
426 .await?;
427 }
428
429 packets::RequestSpecific::FindNodeRequest(arguments) => {
430 self.common_request_handling(addr, &msg)?;
431 let reply = {
432 let state = self.state.lock().unwrap();
433 let nearest = state.buckets.get_nearest_nodes(
434 &arguments.target,
435 Some(&arguments.requester_id),
436 );
437 MessageBuilder::new_find_node_response()
438 .sender_id(state.our_id)
439 .transaction_id(msg.transaction_id)
440 .requester_ip(addr)
441 .nodes(nearest)
442 .build()?
443 };
444
445 self.socket
446 .send_to(reply, addr, Some(arguments.requester_id))
447 .await?;
448 }
449
450 packets::RequestSpecific::AnnouncePeerRequest(arguments) => {
451 self.common_request_handling(addr, &msg)?;
452 let reply = {
453 let mut state = self.state.lock().unwrap();
454
455 let is_token_valid = arguments.token
456 == calculate_token(&addr, state.token_secret.clone())
457 || arguments.token
458 == calculate_token(&addr, state.old_token_secret.clone());
459
460 if is_token_valid {
461 let sockaddr = match arguments.implied_port {
462 Some(implied_port) if implied_port => addr,
463
464 _ => {
465 let mut tmp = addr;
466 tmp.set_port(arguments.port);
467 tmp
468 }
469 };
470
471 state
472 .peer_storage
473 .announce_peer(arguments.info_hash, sockaddr);
474
475 Some(
476 MessageBuilder::new_announce_peer_response()
477 .sender_id(state.our_id)
478 .transaction_id(msg.transaction_id.clone())
479 .requester_ip(addr)
480 .build()?,
481 )
482 } else {
483 None
484 }
485 };
486
487 if let Some(reply) = reply {
488 self.socket
489 .send_to(reply, addr, Some(arguments.requester_id))
490 .await?;
491 }
492 }
493
494 packets::RequestSpecific::SampleInfoHashesRequest(arguments) => {
495 self.common_request_handling(addr, &msg)?;
496 let reply = {
497 let state = self.state.lock().unwrap();
498
499 let nearest = state.buckets.get_nearest_nodes(
500 &arguments.target,
501 Some(&arguments.requester_id),
502 );
503
504 let (info_hashes, total_info_hashes) = {
505 let info_hashes = state.peer_storage.get_info_hashes();
506 let total_info_hashes = info_hashes.len();
507 let info_hashes = {
508 let mut rng = thread_rng();
509 state
510 .peer_storage
511 .get_info_hashes()
512 .as_mut_slice()
513 .partial_shuffle(
514 &mut rng,
515 state.settings.max_sample_response,
516 )
517 .0
518 .to_vec()
519 };
520 (info_hashes, total_info_hashes)
521 };
522
523 MessageBuilder::new_sample_infohashes_response()
524 .sender_id(state.our_id)
525 .transaction_id(msg.transaction_id)
526 .requester_ip(addr)
527 .interval(Duration::from_secs(
528 state.settings.min_sample_interval_secs.try_into().unwrap(),
529 ))
530 .nodes(nearest)
531 .samples(info_hashes)
532 .num_infohashes(total_info_hashes)
533 .build()?
534 };
535
536 self.socket
537 .send_to(reply, addr, Some(arguments.requester_id))
538 .await?;
539 }
540 }
541 }
542
543 packets::MessageType::Response(_) => { }
545
546 _ => {
547 warn!(target: "rustydht_lib::DHT",
548 "Received unsupported/unexpected KRPCMessage variant from {:?}: {:?}",
549 addr, msg
550 );
551 }
552 }
553
554 Ok(())
555 }
556
557 async fn send_packet_to_subscribers(&self, msg: packets::Message, _addr: SocketAddr) {
558 let event = DHTEvent {
560 event_type: DHTEventType::MessageReceived(MessageReceivedEvent { message: msg }),
561 };
562 let mut state = self.state.lock().unwrap();
563 state.subscribers.retain(|sub| {
564 eprintln!("Gotta do notifications for {:?}", event);
565 match sub.try_send(event.clone()) {
566 Ok(()) => true,
567 Err(e) => match e {
568 tokio::sync::mpsc::error::TrySendError::Closed(_) => {
569 trace!(target: "rustydht_lib::DHT", "Removing channel for closed DHTEvent subscriber");
571 false
572 }
573 tokio::sync::mpsc::error::TrySendError::Full(_) => {
574 warn!(target: "rustydht_lib::DHT", "DHTEvent subscriber channel is full - can't send event {:?}", event);
575 true
576 }
577 }
578 }
579 });
580 }
581
582 async fn periodic_buddy_ping(
583 &self,
584 shutdown: shutdown::ShutdownReceiver,
585 ) -> Result<(), RustyDHTError> {
586 loop {
587 let ping_check_interval_secs =
588 self.state.lock().unwrap().settings.ping_check_interval_secs;
589 sleep(Duration::from_secs(ping_check_interval_secs)).await;
590
591 let reverify_interval_secs = {
593 let mut state = self.state.lock().unwrap();
594 let count = state.buckets.count();
595 debug!(target: "rustydht_lib::DHT",
596 "Pruning node buckets. Storage has {} unverified, {} verified",
597 count.0,
598 count.1,
599 );
600 let reverify_grace_period_secs = state.settings.reverify_grace_period_secs;
601 let verify_grace_period_secs = state.settings.verify_grace_period_secs;
602 state.buckets.prune(
603 Duration::from_secs(reverify_grace_period_secs),
604 Duration::from_secs(verify_grace_period_secs),
605 );
606
607 state.settings.reverify_interval_secs
608 };
609 match Instant::now().checked_sub(Duration::from_secs(reverify_interval_secs)) {
610 None => {
611 debug!(target: "rustydht_lib::DHT", "Monotonic clock underflow - skipping this round of pings");
612 }
613
614 Some(ping_if_older_than) => {
615 debug!(target: "rustydht_lib::DHT", "Sending pings to all nodes that have never verified or haven't been verified in a while");
616 let (unverified, verified) = {
617 let state = self.state.lock().unwrap();
618 (
619 state.buckets.get_all_unverified(),
620 state.buckets.get_all_verified(),
621 )
622 };
623 for wrapper in unverified {
625 if let Some(last_verified) = wrapper.last_verified {
627 if last_verified >= ping_if_older_than {
628 continue;
629 }
630 trace!(target: "rustydht_lib::DHT", "Sending ping to reverify backup {:?}", wrapper.node);
631 } else {
632 trace!(target: "rustydht_lib::DHT",
633 "Sending ping to verify {:?} (last seen {} seconds ago)",
634 wrapper.node,
635 (Instant::now() - wrapper.last_seen).as_secs()
636 );
637 }
638 self.ping_internal(
639 shutdown.clone(),
640 wrapper.node.address,
641 Some(wrapper.node.id),
642 )
643 .await?;
644 }
645
646 for wrapper in verified {
648 if let Some(last_verified) = wrapper.last_verified {
649 if last_verified >= ping_if_older_than {
650 continue;
651 }
652 }
653 trace!(target: "rustydht_lib::DHT", "Sending ping to reverify {:?}", wrapper.node);
654 self.ping_internal(
655 shutdown.clone(),
656 wrapper.node.address,
657 Some(wrapper.node.id),
658 )
659 .await?;
660 }
661 }
662 }
663 }
664 }
665
666 async fn periodic_find_node(
667 &self,
668 shutdown: shutdown::ShutdownReceiver,
669 ) -> Result<(), RustyDHTError> {
670 loop {
671 let find_node_interval_secs =
672 self.state.lock().unwrap().settings.find_nodes_interval_secs;
673 sleep(Duration::from_secs(find_node_interval_secs)).await;
674
675 let (count_unverified, count_verified) = self.state.lock().unwrap().buckets.count();
676
677 if count_verified == 0 {
680 self.ping_routers(shutdown.clone()).await?;
681 }
682
683 let (nearest_nodes, id_near_us) = {
685 let state = self.state.lock().unwrap();
686 if count_unverified > state.settings.find_nodes_skip_count {
687 debug!(target: "rustydht_lib::DHT", "Skipping find_node as we already have enough unverified");
688 continue;
689 }
690
691 let id_near_us = state.our_id.make_mutant(4).unwrap();
692
693 (
695 state.buckets.get_nearest_nodes(&id_near_us, None),
696 id_near_us,
697 )
698 };
699 trace!(
700 target: "rustydht_lib::DHT",
701 "Sending find_node to {} nodes about {:?}",
702 nearest_nodes.len(),
703 id_near_us
704 );
705 for node in nearest_nodes {
706 self.find_node_internal(shutdown.clone(), node.address, Some(node.id), id_near_us)
707 .await?;
708 }
709 }
710 }
711
712 async fn periodic_ip4_maintenance(&self) -> Result<(), RustyDHTError> {
713 loop {
714 sleep(Duration::from_secs(10)).await;
715
716 let mut state = self.state.lock().unwrap();
717 state.ip4_source.decay();
718
719 if let Some(ip) = state.ip4_source.get_best_ipv4() {
720 let ip = IpAddr::V4(ip);
721 if !state.our_id.is_valid_for_ip(&ip) {
722 let new_id = Id::from_ip(&ip);
723 info!(target: "rustydht_lib::DHT",
724 "Our current id {} is not valid for IP {}. Using new id {}",
725 state.our_id,
726 ip,
727 new_id
728 );
729 state.our_id = new_id;
730 state.buckets.set_id(new_id);
731 }
732 }
733 }
734 }
735
736 async fn periodic_router_ping(
737 &self,
738 shutdown: shutdown::ShutdownReceiver,
739 ) -> Result<(), RustyDHTError> {
740 loop {
741 let router_ping_interval_secs = self
742 .state
743 .lock()
744 .unwrap()
745 .settings
746 .router_ping_interval_secs;
747 sleep(Duration::from_secs(router_ping_interval_secs)).await;
748 debug!(target: "rustydht_lib::DHT", "Pinging routers");
749 let shutdown_clone = shutdown.clone();
750 self.ping_routers(shutdown_clone).await?;
751 }
752 }
753
754 async fn periodic_token_rotation(&self) -> Result<(), RustyDHTError> {
755 loop {
756 sleep(Duration::from_secs(300)).await;
757 self.rotate_token_secrets();
758 }
759 }
760
761 async fn ping_internal(
763 &self,
764 shutdown: shutdown::ShutdownReceiver,
765 target: SocketAddr,
766 target_id: Option<Id>,
767 ) -> Result<(), RustyDHTError> {
768 let state = self.state.clone();
769 let socket = self.socket.clone();
770 shutdown::ShutdownReceiver::spawn_with_shutdown(
771 shutdown,
772 async move {
773 let req = {
774 let state = state.lock().unwrap();
775 MessageBuilder::new_ping_request()
776 .sender_id(state.our_id)
777 .read_only(state.settings.read_only)
778 .build()
779 .expect("Failed to build ping packet")
780 };
781
782 if let Err(e) =
783 DHT::common_send_and_handle_response(state, socket, req, target, target_id)
784 .await
785 {
786 match e {
787 RustyDHTError::TimeoutError(e) => {
788 debug!(target: "rustydht_lib::DHT", "Ping timed out: {}", e);
789 }
790
791 _ => {
792 error!(target: "rustydht_lib::DHT", "Error during ping: {}", e);
793 }
794 }
795 }
796 },
797 format!("ping to {}", target),
798 Some(Duration::from_secs(5)),
799 );
800 Ok(())
801 }
802
803 async fn common_send_and_handle_response(
811 state: Arc<Mutex<DHTState>>,
812 socket: Arc<DHTSocket>,
813 msg: packets::Message,
814 target: SocketAddr,
815 target_id: Option<Id>,
816 ) -> Result<packets::Message, RustyDHTError> {
817 if !matches!(msg.message_type, packets::MessageType::Request(_)) {
818 return Err(RustyDHTError::GeneralError(anyhow!(
819 "This method is only for sending requests"
820 )));
821 }
822
823 let maybe_receiver = socket.send_to(msg.clone(), target, target_id).await?;
824 match maybe_receiver {
825 Some(mut receiver) => match receiver.recv().await {
826 Some(reply) => match &reply.message_type {
827 packets::MessageType::Response(response_variant) => {
828 let their_id =
831 reply.get_author_id().expect("response doesn't have Id!?");
832 let id_is_valid = their_id.is_valid_for_ip(&target.ip());
833
834 if id_is_valid {
837 let mut state = state.lock().unwrap();
838 DHT::ip4_vote_helper(&mut state, &target, &reply);
839 state
840 .buckets
841 .add_or_update(Node::new(their_id, target), true);
842 }
843
844 if let packets::ResponseSpecific::FindNodeResponse(args) = response_variant {
848 let mut state = state.lock().unwrap();
849 for node in &args.nodes {
850 if node.id.is_valid_for_ip(&node.address.ip()) {
851 state.buckets.add_or_update(node.clone(), false);
852 }
853 }
854 }
855
856 Ok(reply)
857 }
858
859 _ => Err(RustyDHTError::GeneralError(anyhow!("Received wrong Message type as response from {}. {:?}", target, reply)))
860 },
861
862 None => Err(RustyDHTError::TimeoutError(anyhow!("Response channel was cleaned up while we were waiting for a reply from {}. Message we sent: {:?}", target, msg)))
863 },
864
865 None => Err(RustyDHTError::GeneralError(anyhow!("Didn't get a response channel after sending a request to {}. We sent: {:?}", target, msg)))
866 }
867 }
868
869 async fn ping_router<G: AsRef<str>>(
870 &self,
871 shutdown: shutdown::ShutdownReceiver,
872 hostname: G,
873 ) -> Result<(), RustyDHTError> {
874 let hostname = hostname.as_ref();
875 let resolve = lookup_host(hostname).await;
877 if let Err(err) = resolve {
878 warn!(
882 target: "rustydht_lib::DHT",
883 "Failed to resolve host {} due to error {:#?}. Try again later.",
884 hostname, err
885 );
886 return Ok(());
887 }
888
889 for socket_addr in resolve.unwrap() {
890 if socket_addr.is_ipv4() {
891 let shutdown_clone = shutdown.clone();
892 self.ping_internal(shutdown_clone, socket_addr, None)
893 .await?;
894 break;
895 }
896 }
897 Ok(())
898 }
899
900 async fn ping_routers(
902 &self,
903 shutdown: shutdown::ShutdownReceiver,
904 ) -> Result<(), RustyDHTError> {
905 let mut futures = futures::stream::FuturesUnordered::new();
906 let routers = self.state.lock().unwrap().settings.routers.clone();
907 for hostname in routers {
908 let shutdown_clone = shutdown.clone();
909 futures.push(self.ping_router(shutdown_clone, hostname));
910 }
911 while let Some(result) = futures.next().await {
912 result?;
913 }
914 Ok(())
915 }
916
917 fn rotate_token_secrets(&self) {
918 let mut state = self.state.lock().unwrap();
919 let new_token_secret = make_token_secret(state.settings.token_secret_size);
920
921 state.old_token_secret = state.token_secret.clone();
922 state.token_secret = new_token_secret;
923 debug!(
924 target: "rustydht_lib::DHT",
925 "Rotating token secret. New secret is {:?}, old secret is {:?}",
926 state.token_secret,
927 state.old_token_secret
928 );
929 }
930
931 async fn find_node_internal(
932 &self,
933 shutdown: shutdown::ShutdownReceiver,
934 dest: SocketAddr,
935 dest_id: Option<Id>,
936 target: Id,
937 ) -> Result<(), RustyDHTError> {
938 let state = self.state.clone();
939 let socket = self.socket.clone();
940 shutdown::ShutdownReceiver::spawn_with_shutdown(
941 shutdown,
942 async move {
943 let req = {
944 let state = state.lock().unwrap();
945 MessageBuilder::new_find_node_request()
946 .sender_id(state.our_id)
947 .read_only(state.settings.read_only)
948 .target(target)
949 .build()
950 .expect("Failed to build ping packet")
951 };
952
953 if let Err(e) =
954 DHT::common_send_and_handle_response(state, socket, req, dest, dest_id).await
955 {
956 match e {
957 RustyDHTError::TimeoutError(e) => {
958 debug!(target: "rustydht_lib::DHT", "find_node timed out: {}", e);
959 }
960
961 _ => {
962 error!(target: "rustydht_lib::DHT", "Error during find_node: {}", e);
963 }
964 }
965 }
966 },
967 format!("find_node to {} for {}", dest, target),
968 Some(Duration::from_secs(5)),
969 );
970 Ok(())
971 }
972
973 fn ip4_vote_helper(state: &mut DHTState, addr: &SocketAddr, msg: &packets::Message) {
975 if let IpAddr::V4(their_ip) = addr.ip() {
976 if let Some(SocketAddr::V4(they_claim_our_sockaddr)) = &msg.requester_ip {
977 state
978 .ip4_source
979 .add_vote(their_ip, *they_claim_our_sockaddr.ip());
980 }
981 }
982 }
983}
984
985fn calculate_token<T: AsRef<[u8]>>(remote: &SocketAddr, secret: T) -> [u8; 4] {
989 let secret = secret.as_ref();
990 let mut digest = crc32::Digest::new(crc32::CASTAGNOLI);
991 let octets = match remote.ip() {
993 std::net::IpAddr::V4(v4) => v4.octets().to_vec(),
994 std::net::IpAddr::V6(v6) => v6.octets().to_vec(),
995 };
996 digest.write(&octets);
997 digest.write(secret);
998 let checksum: u32 = digest.sum32();
999
1000 checksum.to_be_bytes()
1001}
1002
1003fn make_token_secret(size: usize) -> Vec<u8> {
1004 let mut token_secret = vec![0; size];
1005 token_secret.fill_with(|| thread_rng().gen());
1006 token_secret
1007}
1008
#[cfg(test)]
mod test {
    use super::*;
    use crate::common::ipv4_addr_src::StaticIPV4AddrSource;
    use crate::dht::DHTBuilder;
    use crate::dht::DHTSettingsBuilder;
    use anyhow::anyhow;
    use std::boxed::Box;
    use std::net::{Ipv4Addr, SocketAddrV4};

    /// Builds, but does not start, a DHT on 127.0.0.1:`port`.
    ///
    /// The DHT uses the fixed id from [`get_dht_id`] and a static external IPv4 of 1.2.3.4.
    /// Returns it with both ends of its shutdown channel.
    async fn make_test_dht(
        port: u16,
    ) -> (DHT, shutdown::ShutdownSender, shutdown::ShutdownReceiver) {
        let ipv4 = Ipv4Addr::new(1, 2, 3, 4);
        let phony_ip4 = Box::new(StaticIPV4AddrSource::new(ipv4));
        let (tx, rx) = shutdown::create_shutdown();
        (
            DHTBuilder::new()
                .initial_id(get_dht_id())
                .listen_addr(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))
                .ip_source(phony_ip4)
                .build(rx.clone())
                .unwrap(),
            tx,
            rx,
        )
    }

    /// Builds a test DHT on `port`, runs its event loop in a background task, and returns the
    /// sender used to shut it down.
    async fn start_test_dht(port: u16) -> shutdown::ShutdownSender {
        let (dht, shutdown_tx, shutdown_rx) = make_test_dht(port).await;
        shutdown::ShutdownReceiver::spawn_with_shutdown(
            shutdown_rx,
            async move {
                dht.run_event_loop().await.unwrap();
            },
            "Test DHT",
            Some(Duration::from_secs(10)),
        );
        shutdown_tx
    }

    #[tokio::test]
    async fn test_responds_to_ping() -> Result<(), RustyDHTError> {
        let requester_id = Id::from_random(&mut thread_rng());
        let ping_request = MessageBuilder::new_ping_request()
            .sender_id(requester_id)
            .build()?;

        let port = 1948;
        let mut shutdown_tx = start_test_dht(port).await;

        let res = send_and_receive(ping_request.clone(), port).await.unwrap();

        assert_eq!(res.transaction_id, ping_request.transaction_id);
        assert_eq!(
            res.message_type,
            packets::MessageType::Response(packets::ResponseSpecific::PingResponse(
                packets::PingResponseArguments {
                    responder_id: get_dht_id()
                }
            ))
        );

        shutdown_tx.shutdown().await;

        Ok(())
    }

    #[tokio::test]
    async fn test_responds_to_get_peers() -> Result<(), RustyDHTError> {
        let requester_id = Id::from_random(&mut thread_rng());
        let desired_info_hash = Id::from_random(&mut thread_rng());
        let request = MessageBuilder::new_get_peers_request()
            .sender_id(requester_id)
            .target(desired_info_hash)
            .build()?;

        let port = 1974;
        let mut shutdown_tx = start_test_dht(port).await;

        let res = send_and_receive(request.clone(), port).await.unwrap();

        assert_eq!(res.transaction_id, request.transaction_id);
        assert!(matches!(
            res.message_type,
            packets::MessageType::Response(packets::ResponseSpecific::GetPeersResponse(
                packets::GetPeersResponseArguments { .. }
            ))
        ));

        shutdown_tx.shutdown().await;

        Ok(())
    }

    #[tokio::test]
    async fn test_responds_to_find_node() -> Result<(), RustyDHTError> {
        let port = 1995;
        let mut shutdown_tx = start_test_dht(port).await;

        let requester_id = Id::from_random(&mut thread_rng());
        let target = Id::from_random(&mut thread_rng());
        let request = MessageBuilder::new_find_node_request()
            .sender_id(requester_id)
            .target(target)
            .build()?;
        let res = send_and_receive(request.clone(), port).await.unwrap();

        assert_eq!(res.transaction_id, request.transaction_id);
        assert!(matches!(
            res.message_type,
            packets::MessageType::Response(packets::ResponseSpecific::FindNodeResponse(
                packets::FindNodeResponseArguments { .. }
            ))
        ));

        shutdown_tx.shutdown().await;

        Ok(())
    }

    #[tokio::test]
    async fn test_responds_to_announce_peer() -> Result<(), RustyDHTError> {
        let requester_id = Id::from_random(&mut thread_rng());
        let info_hash = Id::from_random(&mut thread_rng());
        let port = 2014;
        let mut shutdown_tx = start_test_dht(port).await;

        // Step 1: get_peers to obtain a token for our address.
        let reply = send_and_receive(
            MessageBuilder::new_get_peers_request()
                .sender_id(requester_id)
                .target(info_hash)
                .build()?,
            port,
        )
        .await
        .unwrap();

        let token = if let packets::MessageType::Response(
            packets::ResponseSpecific::GetPeersResponse(packets::GetPeersResponseArguments {
                token,
                ..
            }),
        ) = reply.message_type
        {
            token
        } else {
            return Err(RustyDHTError::GeneralError(anyhow!("Didn't get token")));
        };

        // Step 2: announce with that token.
        let reply = send_and_receive(
            MessageBuilder::new_announce_peer_request()
                .sender_id(requester_id)
                .target(info_hash)
                .port(1234)
                .token(token)
                .build()?,
            port,
        )
        .await
        .unwrap();

        // The announce_peer reply is expected to decode as a PingResponse; both carry only the
        // responder id.
        assert!(matches!(
            reply.message_type,
            packets::MessageType::Response(packets::ResponseSpecific::PingResponse(
                packets::PingResponseArguments { .. }
            ))
        ));

        // Step 3: the announced peer should now be returned by get_peers.
        let reply = send_and_receive(
            MessageBuilder::new_get_peers_request()
                .sender_id(requester_id)
                .target(info_hash)
                .build()?,
            port,
        )
        .await
        .unwrap();

        let peers = if let packets::MessageType::Response(
            packets::ResponseSpecific::GetPeersResponse(packets::GetPeersResponseArguments {
                values: packets::GetPeersResponseValues::Peers(p),
                ..
            }),
        ) = reply.message_type
        {
            p
        } else {
            return Err(RustyDHTError::GeneralError(anyhow!("Didn't get peers")));
        };
        assert_eq!(peers.len(), 1);
        assert_eq!(peers[0].port(), 1234);

        shutdown_tx.shutdown().await;

        Ok(())
    }

    #[tokio::test]
    async fn test_responds_to_sample_infohashes() -> Result<(), RustyDHTError> {
        let requester_id = Id::from_random(&mut thread_rng());
        let target = Id::from_random(&mut thread_rng());
        let request = MessageBuilder::new_sample_infohashes_request()
            .sender_id(requester_id)
            .target(target)
            .build()?;

        let port = 2037;
        let mut shutdown_tx = start_test_dht(port).await;

        let res = send_and_receive(request.clone(), port).await.unwrap();

        assert_eq!(res.transaction_id, request.transaction_id);
        assert!(matches!(
            res.message_type,
            packets::MessageType::Response(packets::ResponseSpecific::SampleInfoHashesResponse(
                packets::SampleInfoHashesResponseArguments { num: 0, .. }
            ))
        ));

        shutdown_tx.shutdown().await;

        Ok(())
    }

    #[tokio::test]
    async fn test_event_loop_pings_routers() {
        let (mut shutdown_tx, shutdown_rx) = shutdown::create_shutdown();
        let port1 = 2171;
        let dht1 = Arc::new(
            DHTBuilder::new()
                .initial_id(get_dht_id())
                .listen_addr(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port1))
                .ip_source(Box::new(StaticIPV4AddrSource::new(Ipv4Addr::new(
                    1, 2, 3, 4,
                ))))
                .settings(DHTSettingsBuilder::new().routers(vec![]).build())
                .build(shutdown_rx.clone())
                .unwrap(),
        );

        // dht2 uses dht1 as its only router.
        let dht2 = Arc::new(
            DHTBuilder::new()
                .initial_id(get_dht_id().make_mutant(4).unwrap())
                .listen_addr(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))
                .ip_source(Box::new(StaticIPV4AddrSource::new(Ipv4Addr::new(
                    1, 2, 3, 4,
                ))))
                .settings(
                    DHTSettingsBuilder::new()
                        .router_ping_interval_secs(1)
                        .routers(vec![format!("127.0.0.1:{}", port1)])
                        .build(),
                )
                .build(shutdown_rx.clone())
                .unwrap(),
        );

        let mut receiver = dht2.subscribe();

        shutdown::ShutdownReceiver::spawn_with_shutdown(
            shutdown_rx.clone(),
            async move {
                dht1.run_event_loop().await.unwrap();
            },
            "DHT1",
            None,
        );

        let dht2_clone = dht2.clone();
        shutdown::ShutdownReceiver::spawn_with_shutdown(
            shutdown_rx,
            async move { dht2_clone.run_event_loop().await.unwrap() },
            "DHT2",
            None,
        );

        // Wait for the first event dht2 sees (presumably dht1's reply to the router ping).
        // The wait is bounded so a regression fails instead of hanging the test run.
        tokio::time::timeout(Duration::from_secs(10), receiver.recv())
            .await
            .expect("dht2 did not receive any message from its router in time");
        let (unverified, verified) = dht2.state.lock().unwrap().buckets.count();

        drop(dht2);

        shutdown_tx.shutdown().await;
        assert_eq!(unverified, 0);
        assert_eq!(verified, 1);
    }

    #[tokio::test]
    async fn test_token_secret_rotation() {
        let ipv4 = Ipv4Addr::new(1, 2, 3, 4);
        let phony_ip4 = Box::new(StaticIPV4AddrSource::new(ipv4));
        let port = 2244;

        let dht = DHTBuilder::new()
            .initial_id(get_dht_id())
            .listen_addr(SocketAddrV4::new(Ipv4Addr::LOCALHOST, port))
            .ip_source(phony_ip4)
            .settings(DHTSettingsBuilder::new().routers(vec![]).build())
            .build(shutdown::create_shutdown().1)
            .unwrap();

        assert_eq!(
            dht.state.lock().unwrap().token_secret.len(),
            DHTSettings::default().token_secret_size
        );

        dht.rotate_token_secrets();
        assert_eq!(
            dht.state.lock().unwrap().old_token_secret.len(),
            DHTSettings::default().token_secret_size
        );
        assert_eq!(
            dht.state.lock().unwrap().token_secret.len(),
            DHTSettings::default().token_secret_size
        );

        let state = dht.state.lock().unwrap();
        assert_ne!(state.old_token_secret, state.token_secret);
    }

    /// Fixed node id used by the test DHTs.
    fn get_dht_id() -> Id {
        Id::from_hex("0011223344556677889900112233445566778899").unwrap()
    }

    /// Sends `msg` from a fresh local UDP socket to 127.0.0.1:`port` and parses the first
    /// datagram that comes back.
    async fn send_and_receive(
        msg: packets::Message,
        port: u16,
    ) -> Result<packets::Message, RustyDHTError> {
        let sock = UdpSocket::bind("127.0.0.1:0").await.unwrap();
        sock.send_to(
            &msg.clone().to_bytes().unwrap(),
            format!("127.0.0.1:{}", port),
        )
        .await
        .unwrap();
        let mut recv_buf = [0; 2048];
        let num_read = sock.recv_from(&mut recv_buf).await.unwrap().0;
        packets::Message::from_bytes(&recv_buf[..num_read])
    }
}