1use std::{
2 cell::RefCell,
3 net::{IpAddr, SocketAddr},
4 path::Path,
5 rc::Rc,
6 time::Duration,
7};
8
9use anyhow::Context;
10use futures::future::join_all;
11use krpc::{
12 protocol::{FindNodesResponse, GetPeers, GetPeersResponseBody, Ping},
13 setup_krpc, KrpcClient, KrpcServer,
14};
15use sha1::{Digest, Sha1};
16use slotmap::Key;
17use time::OffsetDateTime;
18use token_store::TokenStore;
19use tokio::sync::Notify;
20
21use crate::{
22 krpc::protocol::{serialize_compact_peers, AnnouncePeer, Answer, FindNodes, Query},
23 node::{Node, NodeId, NodeStatus, ID_MAX, ID_ZERO},
24 routing_table::{BucketId, RoutingTable, BUCKET_SIZE},
25};
26
27mod krpc;
28mod node;
29mod routing_table;
30mod token_store;
31
/// `host:port` names of the nodes that `Dht::bootstrap` resolves and pings
/// to seed a freshly created routing table.
const BOOTSTRAP_NODES: [&str; 5] = [
    "router.bittorrent.com:6881",
    "dht.transmissionbt.com:6881",
    "router.bitcomet.com:6881",
    "dht.aelitis.com:6881",
    "bootstrap.jami.net:4222",
];
39
40#[inline]
41fn generate_node_id() -> NodeId {
42 let id = rand::random::<[u8; 20]>();
43 let mut hasher = Sha1::new();
44 hasher.update(id);
45 NodeId::from(hasher.finalize().as_slice())
46}
47
48#[inline]
50fn load_table(path: &Path) -> Option<RoutingTable> {
51 serde_json::from_reader(std::fs::File::open(path).ok()?).ok()?
52}
53
/// Storage of known peers per torrent, consulted when answering incoming
/// `get_peers` and `announce_peer` queries.
pub trait PeerProvider {
    /// Returns the peers known for `info_hash`, or `None` if there are none;
    /// in the `None` case the query is answered with the closest known nodes instead.
    fn get_peers(&self, info_hash: [u8; 20]) -> Option<Vec<SocketAddr>>;
    /// Records `peer` as a peer for `info_hash`. Called for announce requests
    /// that carried a valid token.
    fn insert_peer(&self, info_hash: [u8; 20], peer: SocketAddr);
}
58
/// Delay before a scheduled bucket refresh runs; a refresh is skipped when the
/// bucket changed more recently than this.
const REFRESH_TIMEOUT: Duration = Duration::from_secs(15 * 60);
64
/// Handle to a DHT node.
///
/// Cloning is used to hand the node to spawned tasks; the routing table and the
/// notifier sit behind `Rc`, so clones share them.
#[derive(Clone)]
pub struct Dht {
    /// Sends outgoing queries (ping, find_nodes, get_peers, announce_peer).
    krpc_client: KrpcClient,
    /// Receives incoming queries; the handler is installed by `handle_incoming`.
    krpc_server: KrpcServer,
    /// Routing table shared between all clones of this handle.
    routing_table: Rc<RefCell<RoutingTable>>,
    /// Generates and validates tokens for incoming queries and keeps the tokens
    /// received from nodes we asked for peers.
    token_store: TokenStore,
    /// Port of the address the node was bound to; sent in announce requests.
    port: u16,
    /// Notified whenever `insert_node` added a node; `find_peers` waits on it
    /// between search rounds.
    node_added_notify: Rc<Notify>,
}
76
77impl Dht {
78 pub async fn new(
79 bind_addr: SocketAddr,
80 peer_provider: impl PeerProvider + 'static,
81 ) -> anyhow::Result<Self> {
82 let port = bind_addr.port();
83 let (client, server) = setup_krpc(bind_addr).await?;
84 if let Some(table) = load_table(Path::new("routing_table.json")) {
85 log::info!("Loading existing table");
86 let bucket_ids: Vec<_> = table.bucket_ids().collect();
87 let dht = Dht {
88 krpc_client: client,
89 krpc_server: server,
90 routing_table: Rc::new(RefCell::new(table)),
91 node_added_notify: Rc::new(Notify::new()),
92 token_store: TokenStore::new(),
93 port,
94 };
95 for bucket_id in bucket_ids {
96 dht.schedule_refresh(bucket_id);
97 }
98 dht.handle_incoming(peer_provider);
99 Ok(dht)
100 } else {
101 let node_id = generate_node_id();
102 let routing_table = RoutingTable::new(node_id);
103
104 let dht = Dht {
105 krpc_client: client,
106 krpc_server: server,
107 routing_table: Rc::new(RefCell::new(routing_table)),
108 node_added_notify: Rc::new(Notify::new()),
109 token_store: TokenStore::new(),
110 port,
111 };
112
113 log::info!("Bootstrapping");
114 dht.handle_incoming(peer_provider);
115 dht.bootstrap().await?;
116 log::info!("Bootstrap successful");
117
118 Ok(dht)
119 }
120 }
121
122 pub fn find_peers(&self, info_hash: &[u8]) -> tokio::sync::mpsc::Receiver<Vec<SocketAddr>> {
123 let this = self.clone();
130 let (tx, rc) = tokio::sync::mpsc::channel(64);
131 let info_hash = NodeId::from(info_hash);
132 tokio_uring::spawn(async move {
133 while !tx.is_closed() {
134 log::debug!("Start search for peers");
135 this.search(info_hash, true).await.unwrap();
136 let nodes = this
137 .routing_table
138 .borrow()
139 .get_k_closest(BUCKET_SIZE, &info_hash);
140 this.get_peers_from_nodes(&info_hash, &nodes, tx.clone())
141 .await
142 .unwrap();
143 let own_id = this.routing_table.borrow().own_id;
144 for node in nodes {
148 if let IpAddr::V4(addr) = node.addr.ip() {
149 if let Some(token) = this.token_store.get_token(addr) {
150 if this
151 .krpc_client
152 .announce_peer(AnnouncePeer {
153 id: own_id,
154 info_hash: *info_hash,
155 implied_port: true,
156 port: this.port,
157 token: serde_bytes::ByteBuf::from(token),
158 })
159 .with_timeout(Duration::from_secs(3))
160 .send(node.addr)
161 .await
162 .is_err()
163 {
164 log::error!("Announce failed!");
165 }
166 } else {
167 log::warn!("Token not found for: {addr}");
168 }
169 } else {
170 panic!("Tokens may only be stored for nodes with Ipv4 addrs.");
171 }
172 }
173 log::debug!("Waiting for notify more nodes");
174 let _ = tokio::time::timeout(
176 Duration::from_secs(30),
177 this.node_added_notify.notified(),
178 )
179 .await;
180 }
181 });
182 rc
183 }
184
185 fn handle_incoming(&self, peer_provider: impl PeerProvider + 'static) {
186 let this = self.clone();
187 self.krpc_server.serve(move |mut addr, query| {
188 log::debug!("Received query: {query:?}");
189 let our_id = this.routing_table.borrow().own_id;
190 let ip = match addr.ip() {
191 IpAddr::V4(ip) => ip,
192 IpAddr::V6(_) => {
193 log::error!("Ip v6 addresses aren't supported for token generation");
194 return Err(krpc::error::KrpcError::generic(
195 "Ip v6 addresses aren't supported for token generation".to_owned(),
196 ));
197 }
198 };
199 match query {
200 Query::FindNode { id: _, target } => {
201 let target = NodeId::from(target.as_slice());
202 let closet = this
203 .routing_table
204 .borrow()
205 .get_k_closest(BUCKET_SIZE, &target);
206 log::debug!("Found: {} nodes closet to {target:?}", closet.len());
207 Ok(Answer::FindNode {
208 id: serde_bytes::ByteBuf::from(our_id.as_bytes()),
209 nodes: krpc::protocol::serialize_compact_nodes(&closet),
210 })
211 }
212 Query::GetPeers { id: _, info_hash } => {
213 if let Ok(info_hash) = info_hash.as_slice().try_into() {
214 if let Some(peers) = peer_provider.get_peers(info_hash) {
215 Ok(Answer::GetPeers {
216 id: serde_bytes::ByteBuf::from(our_id.as_bytes()),
217 token: serde_bytes::ByteBuf::from(
218 this.token_store.generate(ip).to_vec(),
219 ),
220 values: Some(serialize_compact_peers(&peers)),
221 nodes: None,
222 })
223 } else {
224 let target = NodeId::from(info_hash.as_slice());
225 let closet = this
226 .routing_table
227 .borrow()
228 .get_k_closest(BUCKET_SIZE, &target);
229 log::debug!("Found: {} nodes closet to {target:?}", closet.len());
230
231 Ok(Answer::GetPeers {
232 id: serde_bytes::ByteBuf::from(our_id.as_bytes()),
233 token: serde_bytes::ByteBuf::from(
234 this.token_store.generate(ip).to_vec(),
235 ),
236 values: None,
237 nodes: Some(krpc::protocol::serialize_compact_nodes(&closet)),
238 })
239 }
240 } else {
241 Err(krpc::error::KrpcError::protocol(
242 "Invalid infohash".to_owned(),
243 ))
244 }
245 }
246 Query::AnnouncePeer {
247 id: _,
248 implied_port,
249 info_hash,
250 port,
251 token,
252 } => {
253 if this
254 .token_store
255 .validate(ip, bytes::Bytes::copy_from_slice(&token))
256 {
257 log::info!("Recived valid announce peer request");
258 if !implied_port {
259 addr.set_port(port);
260 }
261 if let Ok(info_hash) = info_hash.as_slice().try_into() {
262 peer_provider.insert_peer(info_hash, addr);
263 Ok(Answer::QueriedNodeId {
264 id: serde_bytes::ByteBuf::from(our_id.as_bytes()),
265 })
266 } else {
267 Err(krpc::error::KrpcError::protocol(
268 "Invalid infohash".to_owned(),
269 ))
270 }
271 } else {
272 Err(krpc::error::KrpcError::protocol("Invalid token".to_owned()))
273 }
274 }
275 Query::Ping { id: _ } => {
276 Ok(Answer::QueriedNodeId {
279 id: serde_bytes::ByteBuf::from(our_id.as_bytes()),
280 })
281 }
282 }
283 });
284 }
285
286 pub async fn save(&self, path: &Path) -> anyhow::Result<()> {
287 log::info!("Saving table");
288 let table_json = {
289 let routing_table = self.routing_table.borrow();
290 serde_json::to_vec(&*routing_table)?
291 };
292 let file = tokio_uring::fs::File::create(&path).await?;
293 Ok(())
296 }
297
298 async fn bootstrap(&self) -> anyhow::Result<()> {
299 let our_id = self.routing_table.borrow().own_id;
300 log::debug!("Resolving bootstrap node addrs");
301 let resolve_result = join_all(
302 BOOTSTRAP_NODES
303 .iter()
304 .map(|node_addr| tokio_uring::spawn(tokio::net::lookup_host(node_addr))),
305 )
306 .await;
307
308 log::debug!("Pinging bootstrap nodes");
309 let bootstrap_ping_futures = resolve_result
310 .into_iter()
311 .filter_map(|result| result.map(|inner| inner.ok()).ok().flatten())
313 .filter_map(|mut node_addrs| node_addrs.next())
315 .map(|addr| Node {
316 id: ID_ZERO,
317 addr,
318 last_seen: OffsetDateTime::now_utc(),
319 last_status: NodeStatus::Unknown,
320 })
321 .map(|node| async move {
322 log::debug!("Pinging {}", node.addr);
323 let result = self
324 .krpc_client
325 .ping(Ping { id: our_id })
326 .send(node.addr)
327 .await;
328 (node, result)
329 });
330
331 let mut any_success = false;
332 for (mut node, result) in join_all(bootstrap_ping_futures).await {
333 if let Ok(pong) = result {
334 node.last_seen = OffsetDateTime::now_utc();
335 node.last_status = NodeStatus::Good;
336 node.id = pong.id;
337 log::debug!("Node {} responded", node.addr);
338 assert!(self.insert_node(node, None).await);
339 any_success = true;
340 }
341 }
342
343 if !any_success {
344 anyhow::bail!("Bootstrap failed, node not responsive");
345 } else {
346 Ok(())
347 }
348 }
349
350 async fn refresh_bucket(&self, bucket_id: BucketId) {
353 let (is_full, mut candiate, our_id, id) = {
354 let mut routing_table = self.routing_table.borrow_mut();
355 let our_id = routing_table.own_id;
356
357 let Some(bucket) = routing_table.get_bucket_mut(bucket_id) else {
358 log::error!("Bucket with id {bucket_id:?} to be refreshed no longer exist");
359 return;
360 };
361 if OffsetDateTime::now_utc() - bucket.last_changed() < REFRESH_TIMEOUT {
364 log::debug!(
365 "Refresh task for bucket with id {bucket_id:?} is stale, skipping refresh"
366 );
367 return;
368 }
369
370 log::debug!("Refreshing bucket: {bucket_id:?}");
371 bucket.update_last_changed();
372
373 let id = bucket.random_id();
375
376 let is_full = !bucket.covers(&our_id) && bucket.is_full();
377
378 let Some(candiate) = bucket
379 .nodes_mut()
380 .filter(|node| node.last_status != NodeStatus::Bad)
381 .min_by_key(|node| id.distance(&node.id)) else {
382 log::error!("No nodes left in bucket, refresh failed");
384 return;
385 };
386 (is_full, candiate.clone(), our_id, id)
387 };
388
389 let mut need_refresh_scheduled = true;
390 if is_full {
392 match self
393 .krpc_client
394 .ping(Ping { id: our_id })
395 .with_timeout(Duration::from_secs(3))
396 .send(candiate.addr)
397 .await
398 {
399 Ok(_) => {
400 candiate.last_seen = OffsetDateTime::now_utc();
401 candiate.last_status = NodeStatus::Good;
402 }
403 Err(krpc::error::Error::Timeout(_)) if candiate.last_status == NodeStatus::Good => {
404 candiate.last_status = NodeStatus::Unknown;
405 }
406 Err(krpc::error::Error::Timeout(_))
407 if candiate.last_status == NodeStatus::Unknown =>
408 {
409 candiate.last_status = NodeStatus::Bad;
410 }
411 Err(_err) if candiate.last_status == NodeStatus::Good => {
412 candiate.last_status = NodeStatus::Unknown;
413 candiate.last_seen = OffsetDateTime::now_utc();
414 }
415 Err(_err) => {
416 candiate.last_status = NodeStatus::Bad;
417 candiate.last_seen = OffsetDateTime::now_utc();
418 }
419 }
420 if let Some(node) = self.routing_table.borrow_mut().get_mut(&candiate.id) {
423 *node = candiate;
424 }
425 } else {
426 match self
430 .krpc_client
431 .find_nodes(FindNodes {
432 id: our_id,
433 target: id,
434 })
435 .with_timeout(Duration::from_secs(3))
436 .send(candiate.addr)
437 .await
438 {
439 Ok(FindNodesResponse { id: _, nodes }) => {
440 candiate.last_seen = OffsetDateTime::now_utc();
441 candiate.last_status = NodeStatus::Good;
442 if let Some(node) = self.routing_table.borrow_mut().get_mut(&candiate.id) {
445 *node = candiate;
446 }
447 for node in nodes {
448 if self.insert_node(node.clone(), None).await {
449 need_refresh_scheduled = false;
451 log::debug!("Refreshed bucket found new node: {:?}", node.id);
452 }
453 }
454 }
455 Err(err) => {
456 match (err, candiate.last_status) {
457 (krpc::error::Error::Timeout(_), NodeStatus::Good) => {
458 candiate.last_status = NodeStatus::Unknown;
459 }
460 (krpc::error::Error::Timeout(_), NodeStatus::Unknown) => {
461 candiate.last_status = NodeStatus::Bad;
462 }
463 (krpc::error::Error::IoError(err), _) => {
464 log::warn!("Socket failure: {err}");
465 }
466 (_, NodeStatus::Good) => {
468 candiate.last_status = NodeStatus::Unknown;
469 candiate.last_seen = OffsetDateTime::now_utc();
470 }
471 (_, _) => {
472 candiate.last_status = NodeStatus::Bad;
473 candiate.last_seen = OffsetDateTime::now_utc();
475 }
476 }
477 if let Some(node) = self.routing_table.borrow_mut().get_mut(&candiate.id) {
480 *node = candiate;
481 }
482 }
483 }
484 }
485 if need_refresh_scheduled {
486 self.schedule_refresh(bucket_id);
488 }
489 }
490
491 fn schedule_refresh(&self, bucket_id: BucketId) {
492 assert!(!bucket_id.is_null());
493 let this = self.clone();
494 tokio_uring::spawn(async move {
496 log::debug!("Spawning refresh task for bucket with id {bucket_id:?}");
497 tokio::time::sleep(REFRESH_TIMEOUT).await;
498 this.refresh_bucket(bucket_id).await;
499 });
500 }
501
502 async fn insert_node(&self, node: Node, target_id: Option<NodeId>) -> bool {
503 let mut inserted = false;
504 let (our_id, updated_buckets @ [bucket_id_one, bucket_id_two]) = {
505 let mut routing_table = self.routing_table.borrow_mut();
506 (
507 routing_table.own_id,
508 routing_table.insert_node(node.clone()),
509 )
510 };
511
512 let failed_insert = bucket_id_one.is_null() && bucket_id_two.is_null();
513
514 if failed_insert {
515 let (bucket_id, mut unknown_nodes) = self.find_bucket_unknown_nodes(&node.id).unwrap();
517
518 while unknown_nodes
519 .iter()
520 .any(|node| node.current_status() == NodeStatus::Unknown)
521 {
522 for mut unknown_node in unknown_nodes
525 .iter_mut()
526 .filter(|node| node.current_status() == NodeStatus::Unknown)
527 {
528 match self
529 .krpc_client
530 .ping(Ping { id: our_id })
531 .with_timeout(Duration::from_secs(3))
532 .send(unknown_node.addr)
533 .await
534 {
535 Ok(_) => {
536 unknown_node.last_seen = OffsetDateTime::now_utc();
537 unknown_node.last_status = NodeStatus::Good;
538 }
539 Err(krpc::error::Error::Timeout(_))
540 if unknown_node.last_status == NodeStatus::Good =>
541 {
542 unknown_node.last_status = NodeStatus::Unknown;
543 }
544 Err(krpc::error::Error::Timeout(_))
545 if unknown_node.last_status == NodeStatus::Unknown =>
546 {
547 unknown_node.last_status = NodeStatus::Bad;
549 }
550 Err(_err) if unknown_node.last_status == NodeStatus::Good => {
551 unknown_node.last_status = NodeStatus::Unknown;
552 unknown_node.last_seen = OffsetDateTime::now_utc();
553 }
554 Err(_err) => {
555 unknown_node.last_status = NodeStatus::Bad;
556 unknown_node.last_seen = OffsetDateTime::now_utc();
557 }
558 }
559 }
560 }
561 let mut routing_table = self.routing_table.borrow_mut();
562 let bucket = routing_table.get_bucket_mut(bucket_id).unwrap();
563 for updated_node in unknown_nodes {
565 for current_node in bucket.nodes_mut() {
566 if updated_node.id == current_node.id {
567 *current_node = updated_node.clone();
568 }
569 if current_node.last_status == NodeStatus::Bad && !inserted {
570 inserted = true;
572 *current_node = node.clone();
574 }
575 }
576 }
577 if inserted {
578 bucket.update_last_changed();
581 self.schedule_refresh(bucket_id);
582 } else if let Some(target_id) = target_id {
583 let furthest_away = bucket
587 .nodes_mut()
588 .max_by_key(|node| node.id.distance(&target_id))
589 .unwrap();
590 *furthest_away = node;
591 inserted = true;
592 bucket.update_last_changed();
595 self.schedule_refresh(bucket_id);
596 }
597 } else {
598 for bucket_id in updated_buckets {
600 if !bucket_id.is_null() {
602 self.schedule_refresh(bucket_id);
603 inserted = true;
604 }
605 }
606 }
607 if inserted {
608 self.node_added_notify.notify_one();
609 }
610 inserted
611 }
612
613 fn find_bucket_unknown_nodes(&self, target_id: &NodeId) -> Option<(BucketId, Vec<Node>)> {
614 let routing_table = self.routing_table.borrow_mut();
615 let (bucket_id, bucket) = routing_table.find_bucket(target_id)?;
616 let mut unknown_nodes: Vec<_> = bucket
617 .nodes()
618 .cloned()
619 .filter(|node| node.current_status() == NodeStatus::Unknown)
620 .collect();
621 unknown_nodes.sort_unstable_by(|a, b| a.last_seen.cmp(&b.last_seen));
622 Some((bucket_id, unknown_nodes))
623 }
624
625 async fn search(&self, target: NodeId, force_insert: bool) -> anyhow::Result<()> {
627 let mut prev_min = ID_MAX;
628 let own_id = self.routing_table.borrow().own_id;
629 loop {
630 log::info!("Searching for: {target:?}");
631 let next_to_query = self
632 .routing_table
633 .borrow_mut()
634 .get_closest_mut(&target)
635 .context("No nodes in routing table")?
636 .clone();
637
638 let distance = target.distance(&next_to_query.id);
639 if distance < prev_min {
640 let response = match self
641 .krpc_client
642 .find_nodes(FindNodes { id: own_id, target })
643 .send(next_to_query.addr)
644 .await
645 {
646 Ok(reponse) => {
647 let mut routing_table = self.routing_table.borrow_mut();
648 let queried_node = routing_table.get_mut(&next_to_query.id).unwrap();
651 queried_node.last_status = NodeStatus::Good;
652 queried_node.last_seen = OffsetDateTime::now_utc();
653 reponse
654 }
655 Err(err) => {
656 log::warn!("{next_to_query:?} find nodes query failed: {err}");
658 match next_to_query.last_status {
659 NodeStatus::Good => {
660 let mut routing_table = self.routing_table.borrow_mut();
661 let queried_node =
664 routing_table.get_mut(&next_to_query.id).unwrap();
665 queried_node.last_seen = OffsetDateTime::now_utc();
666 queried_node.last_status = NodeStatus::Unknown;
667 }
668 NodeStatus::Unknown => {
669 let mut routing_table = self.routing_table.borrow_mut();
670 let queried_node =
673 routing_table.get_mut(&next_to_query.id).unwrap();
674 queried_node.last_status = NodeStatus::Bad;
675 log::debug!("Setting status of {next_to_query:?} to bad");
676 }
677 NodeStatus::Bad => {
678 log::debug!("Removing {next_to_query:?} from table");
679 self.routing_table.borrow_mut().remove(&next_to_query)?;
680 }
681 }
682 continue;
683 }
684 };
685 log::debug!("Got nodes from: {next_to_query:?}");
686 for node in response.nodes {
687 let target = force_insert.then_some(target);
688 if self.insert_node(node, target).await {
689 log::debug!("Inserted node");
690 } else {
691 log::debug!("Did not insert node because of full routing table");
692 }
693 }
694 prev_min = distance;
695 } else {
696 break;
697 }
698 }
699 Ok(())
700 }
701
702 async fn get_peers_from_nodes(
704 &self,
705 target: &NodeId,
706 nodes: &[Node],
707 peer_listener: tokio::sync::mpsc::Sender<Vec<SocketAddr>>,
708 ) -> anyhow::Result<()> {
709 let own_id = self.routing_table.borrow().own_id;
710
711 for node in nodes {
715 let response = match self
716 .krpc_client
717 .get_peers(GetPeers {
718 id: own_id,
719 info_hash: target.as_bytes(),
720 })
721 .with_timeout(Duration::from_secs(3))
722 .send(node.addr)
723 .await
724 {
725 Ok(reponse) => {
726 let mut routing_table = self.routing_table.borrow_mut();
727 if let Some(queried_node) = routing_table.get_mut(&node.id) {
728 queried_node.last_status = NodeStatus::Good;
729 }
730 reponse
731 }
732 Err(err) => {
733 log::warn!("{node:?} ping failed: {err}");
734 match node.last_status {
735 NodeStatus::Good => {
736 let mut routing_table = self.routing_table.borrow_mut();
737 if let Some(queried_node) = routing_table.get_mut(&node.id) {
738 queried_node.last_status = NodeStatus::Unknown;
739 }
740 }
741 NodeStatus::Unknown => {
742 let mut routing_table = self.routing_table.borrow_mut();
743 if let Some(queried_node) = routing_table.get_mut(&node.id) {
744 queried_node.last_status = NodeStatus::Bad;
745 }
746 }
747 NodeStatus::Bad => {
748 let _ = self.routing_table.borrow_mut().remove(node);
749 }
750 }
751 continue;
752 }
753 };
754 if let IpAddr::V4(addr) = node.addr.ip() {
755 self.token_store
756 .store_token(addr, response.token.into_vec().into());
757 } else {
758 log::error!("Tokens may only be stored for nodes with Ipv4 addrs.");
759 }
760
761 match response.body {
762 GetPeersResponseBody::Nodes(nodes) => {
763 log::debug!("Got nodes from: {node:?}");
764 for node in nodes.into_iter() {
765 assert!(self.insert_node(node, Some(*target)).await);
766 log::debug!("Inserted node");
767 }
768 }
769 GetPeersResponseBody::Peers(peers) => {
770 log::info!("Got peers! ({})", peers.len());
771 if peer_listener.send(peers).await.is_err() {
772 log::debug!("Peer listener disconnected");
773 }
774 }
775 }
776 }
777 Ok(())
778 }
779
780 pub async fn start(&self) -> anyhow::Result<()> {
781 let own_id = self.routing_table.borrow().own_id;
782 self.search(own_id, false).await
783 }
784}