rustybit_lib/state/
torrent.rs

1use std::collections::{HashMap, VecDeque};
2use std::net::SocketAddrV4;
3use std::sync::atomic::Ordering;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use anyhow::Context;
8use bittorrent_peer_protocol::Block;
9use bitvec::order::Msb0;
10use bitvec::slice::BitSlice;
11use tokio::sync::mpsc::{self, UnboundedReceiver};
12use tokio::sync::RwLock;
13use tokio::time;
14
15use super::event::{PeerEvent, TorrentManagerReq};
16use crate::stats::{DOWNLOADED_BYTES, DOWNLOADED_PIECES, NUMBER_OF_PEERS};
17use crate::storage::StorageOp;
18use crate::torrent_meta::TorrentMeta;
19use crate::util::piece_size_from_idx;
20
/// Download lifecycle of a single piece.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PieceState {
    /// Not assigned to any peer; may be handed out for download.
    Queued,
    /// Assigned to `peer`, which has been downloading it since `start`.
    Downloading {
        peer: SocketAddrV4,
        start: Instant,
    },
    /// Every byte of the piece was received; the hash check is still pending.
    Downloaded,
    /// The piece passed its hash check.
    Verified,
}
28
29/// State that each peer keeps a reference to
30#[derive(Debug)]
31pub struct TorrentSharedState {
32    peer_download_stats: HashMap<SocketAddrV4, (f64, f64)>,
33    pieces: Vec<PieceState>,
34    piece_download_progress: HashMap<u32, usize>,
35    cancellation_req_queue: VecDeque<(SocketAddrV4, u32)>,
36}
37
38impl TorrentSharedState {
39    pub fn new(piece_states: Vec<PieceState>, number_of_pieces: usize) -> Self {
40        TorrentSharedState {
41            peer_download_stats: HashMap::new(),
42            pieces: piece_states,
43            piece_download_progress: HashMap::with_capacity(number_of_pieces),
44            cancellation_req_queue: VecDeque::new(),
45        }
46    }
47}
48
49impl TorrentSharedState {
50    fn get_piece_steal_coeff(&self) -> f64 {
51        let total_pieces = self.pieces.len() as f64;
52        let downloaded_pieces = DOWNLOADED_PIECES.load(Ordering::Relaxed) as f64;
53        if downloaded_pieces / total_pieces >= 0.8 {
54            3.
55        } else {
56            10.
57        }
58    }
59
60    /// Returns either a piece that we failed to download earlier or one that we didn't try yet.
61    pub fn get_next_missing_piece_indexes(
62        &mut self,
63        peer_addr: SocketAddrV4,
64        peer_available_pieces: &BitSlice<u8, Msb0>,
65        number_of_pieces: usize,
66    ) -> anyhow::Result<Option<Vec<(u32, Option<SocketAddrV4>)>>> {
67        let piece_steal_coeff = self.get_piece_steal_coeff();
68        let pieces = self
69            .pieces
70            .iter_mut()
71            .enumerate()
72            .filter_map(|(idx, status)| {
73                if !peer_available_pieces[idx] {
74                    return None;
75                }
76                match status {
77                    PieceState::Queued => {
78                        *status = PieceState::Downloading {
79                            peer: peer_addr,
80                            start: Instant::now(),
81                        };
82                        Some(try_into!(idx, u32).map(|idx| (idx, None)))
83                    }
84                    PieceState::Downloading { peer, start } => {
85                        let peer = *peer;
86                        if peer == peer_addr {
87                            return None;
88                        }
89
90                        let requesting_peer_stats = self.peer_download_stats.entry(peer_addr).or_insert((0., 0.));
91                        if requesting_peer_stats.1 == 0. {
92                            return None;
93                        }
94                        let requesting_peer_avg_time = requesting_peer_stats.0 / requesting_peer_stats.1;
95
96                        let elapsed_secs = start.elapsed().as_secs_f64();
97                        // Compare elapsed time to requesting peer's average piece download time
98                        if elapsed_secs > requesting_peer_avg_time * piece_steal_coeff {
99                            tracing::debug!(
100                                %peer_addr,
101                                stolen_from=%peer,
102                                piece=%idx,
103                                "stole a piece: elapsed time {}, my avg piece time: {}",
104                                elapsed_secs,
105                                requesting_peer_avg_time
106                            );
107
108                            // Update the current peer's download stats
109                            let (ref mut cur_peer_piece_download_times_sum, _) =
110                                self.peer_download_stats.entry(peer).or_insert((0., 0.));
111                            *cur_peer_piece_download_times_sum += elapsed_secs;
112
113                            *status = PieceState::Downloading {
114                                peer: peer_addr,
115                                start: Instant::now(),
116                            };
117                            Some(try_into!(idx, u32).map(|idx| (idx, Some(peer))))
118                        } else {
119                            None
120                        }
121                    }
122                    _ => None,
123                }
124            })
125            .take(number_of_pieces)
126            .collect::<anyhow::Result<Vec<(u32, Option<SocketAddrV4>)>>>()
127            .context("bug: converting piece index to u32 failed - too many pieces?")
128            .map(|vec| if vec.is_empty() { None } else { Some(vec) })?;
129
130        if let Some(pieces) = pieces.as_ref() {
131            for (piece_idx, stolen_from) in pieces.iter().filter(|(_, stolen_from)| stolen_from.is_some()) {
132                self.cancellation_req_queue
133                    .push_back((stolen_from.unwrap(), *piece_idx));
134
135                // Reset piece download progress
136                self.piece_download_progress.insert(*piece_idx, 0);
137            }
138        }
139
140        Ok(pieces)
141    }
142
143    pub fn get_piece_status(&self, idx: usize) -> Option<&PieceState> {
144        self.pieces.get(idx)
145    }
146
147    /// Checks whether the current torrent was fully downloaded
148    pub fn finished_downloading(&self) -> bool {
149        self.pieces.iter().all(|state| state == &PieceState::Verified)
150    }
151
152    /// Returns the total number of pieces for the current torrent
153    pub fn get_number_of_pieces(&self) -> usize {
154        self.pieces.len()
155    }
156
157    pub fn mark_piece_as_verified(&mut self, piece_idx: usize) {
158        if let Some(piece_state) = self.pieces.get_mut(piece_idx) {
159            *piece_state = PieceState::Verified
160        };
161    }
162
163    pub fn on_peer_disconnect(&mut self, dead_peer_addr: &SocketAddrV4) {
164        self.pieces.iter_mut().for_each(|piece| match piece {
165            PieceState::Downloading { peer, .. } if peer == dead_peer_addr => {
166                *piece = PieceState::Queued;
167            }
168            _ => {}
169        })
170    }
171}
172
173#[derive(Debug)]
174pub struct Torrent {
175    torrent_meta: TorrentMeta,
176    shared_state: Arc<RwLock<TorrentSharedState>>,
177    piece_hashes: Vec<[u8; 20]>,
178    /// Channels for sending peer-level events to peers
179    peer_req_txs: HashMap<SocketAddrV4, mpsc::Sender<TorrentManagerReq>>,
180    /// Channel for receiving peer-level events from peers
181    rx: UnboundedReceiver<(SocketAddrV4, PeerEvent)>,
182    /// Channel for receiving request channels for new peers
183    new_peer_req_rx: UnboundedReceiver<(SocketAddrV4, mpsc::Sender<TorrentManagerReq>)>,
184    /// Channel for communicating with the storage backend
185    storage_tx: mpsc::Sender<StorageOp>,
186    /// Channel for receiving the result of checking piece hashes and storage-related errors
187    storage_rx: mpsc::Receiver<(SocketAddrV4, u32, bool)>,
188}
189
190impl Torrent {
191    pub fn new(
192        torrent_meta: TorrentMeta,
193        state: Arc<RwLock<TorrentSharedState>>,
194        piece_hashes: Vec<[u8; 20]>,
195        rx: UnboundedReceiver<(SocketAddrV4, PeerEvent)>,
196        new_peer_req_rx: UnboundedReceiver<(SocketAddrV4, mpsc::Sender<TorrentManagerReq>)>,
197        storage_tx: mpsc::Sender<StorageOp>,
198        storage_rx: mpsc::Receiver<(SocketAddrV4, u32, bool)>,
199    ) -> Self {
200        Torrent {
201            torrent_meta,
202            shared_state: state,
203            peer_req_txs: HashMap::new(),
204            piece_hashes,
205            rx,
206            new_peer_req_rx,
207            storage_tx,
208            storage_rx,
209        }
210    }
211
212    #[tracing::instrument(level = "debug", err, skip(self))]
213    pub async fn handle(&mut self) -> anyhow::Result<()> {
214        let mut interval = time::interval(Duration::from_secs(1));
215        loop {
216            tokio::select! {
217                result = self.rx.recv() => {
218                    if let Some((peer_addr, event)) = result {
219                        match event {
220                            PeerEvent::BlockDownloaded(block) => {
221                                // We need to own the state all this time to avoid some other peer
222                                // stealing the piece
223                                let mut state = self.shared_state.write().await;
224                                if Torrent::verify_piece_not_stolen(&state, peer_addr, try_into!(block.index, usize)?).await? {
225                                    if !self.add_block(&mut state, peer_addr, block).await? {
226                                        drop(state);
227                                        self.disconnect_peer(&peer_addr, "error while adding a block", true).await?;
228                                    };
229                                }
230                            }
231                            PeerEvent::Disconnected => {
232                                tracing::debug!(%peer_addr, "peer exited unexpectedly");
233                                // Drop peer cancellation tx
234                                self.remove_peer_req_tx(&peer_addr);
235                                self.disconnect_peer(&peer_addr, "peer exited", false).await?;
236
237                                if self.peer_req_txs.is_empty() && !self.shared_state.read().await.finished_downloading() {
238                                    anyhow::bail!("All peers exited before finishing the torrent, the download is incomplete");
239                                }
240                            }
241                        };
242                    } else {
243                        // This can only happen if all peers panicked somehow and didn't disconnect
244                        // properly
245                        anyhow::bail!("bug: all peers exited unexepectedly?");
246                    }
247                },
248                result = self.storage_rx.recv() => {
249                    let Some((peer_addr, piece_idx, is_correct)) = result else {
250                        anyhow::bail!("bug: storage backend exited before torrent manager?");
251                    };
252
253                    if !is_correct {
254                        tracing::debug!(
255                            %peer_addr,
256                            %piece_idx,
257                            "piece hash verification failed: disconnecting the peer"
258                        );
259
260                        self.disconnect_peer(&peer_addr, "piece hash verification failed", true).await?;
261                    } else {
262                        let mut shared_state = self.shared_state.write().await;
263                        let piece_state = shared_state.pieces.get_mut(try_into!(piece_idx, usize)?).context("bug: downloaded a ghost piece?")?;
264
265                        *piece_state = PieceState::Verified;
266
267                        let downloaded_bytes = shared_state.piece_download_progress.get(&piece_idx).context("bug: downloaded a piece but didn't track its bytes?")?;
268                        DOWNLOADED_BYTES.fetch_add(*downloaded_bytes, Ordering::Relaxed);
269                        DOWNLOADED_PIECES.fetch_add(1, Ordering::Relaxed);
270
271                        if shared_state.finished_downloading() {
272                            tracing::info!("Successfully finished downloading the torrent");
273                            tracing::debug!("shutting down peers");
274                            for (peer_addr, req_tx) in self.peer_req_txs.drain() {
275                                if req_tx.send(TorrentManagerReq::Disconnect("finished downloading")).await.is_err() {
276                                    tracing::debug!(
277                                        %peer_addr,
278                                        "error while shutting down a peer: peer already dropped the receiver"
279                                    );
280                                }
281                            };
282                            break;
283                        }
284                    }
285                },
286                result = self.new_peer_req_rx.recv() => {
287                    if let Some((peer_addr, peer_tx_channel)) = result {
288                        NUMBER_OF_PEERS.fetch_add(1, Ordering::Relaxed);
289                        self.peer_req_txs.insert(peer_addr, peer_tx_channel);
290                    };
291                }
292                _ = interval.tick() => {}
293            }
294            let mut state = self.shared_state.write().await;
295            while let Some((peer_addr, piece_idx)) = state.cancellation_req_queue.pop_front() {
296                if let Some(sender) = self.get_peer_req_tx(&peer_addr) {
297                    if let Err(e) = sender.send(TorrentManagerReq::CancelPiece(piece_idx)).await {
298                        tracing::debug!(
299                            %peer_addr,
300                            "error while sending a cancellation request: {}", e
301                        );
302                    }
303                }
304            }
305        }
306
307        Ok(())
308    }
309
310    #[tracing::instrument(err, skip(self, state, begin, block))]
311    async fn add_block(
312        &self,
313        state: &mut TorrentSharedState,
314        peer_addr: SocketAddrV4,
315        Block { index, begin, block }: Block,
316    ) -> anyhow::Result<bool> {
317        let downloaded_bytes = state.piece_download_progress.entry(index).or_insert(0);
318
319        let expected_piece_size = piece_size_from_idx(
320            self.torrent_meta.number_of_pieces,
321            self.torrent_meta.total_length,
322            self.torrent_meta.piece_size,
323            try_into!(index, usize)?,
324        );
325
326        if *downloaded_bytes + block.len() > expected_piece_size {
327            tracing::debug!(
328                "piece is larger than expected: {} vs {}",
329                *downloaded_bytes + block.len(),
330                expected_piece_size,
331            );
332            return Ok(false);
333        }
334
335        *downloaded_bytes += block.len();
336
337        self.storage_tx
338            .send(StorageOp::AddBlock(Block { index, begin, block }))
339            .await
340            .with_context(|| {
341                format!(
342                    "Failed to send a block to the storage backend: index {}, in-piece offset: {}",
343                    index, begin
344                )
345            })?;
346
347        if *downloaded_bytes == expected_piece_size {
348            let piece_idx = try_into!(index, usize)?;
349            let Some(expected_piece_hash) = self.get_piece_hash(piece_idx) else {
350                tracing::debug!(
351                    "Wrong piece index: index {}, total pieces: {}",
352                    piece_idx,
353                    self.shared_state.read().await.get_number_of_pieces()
354                );
355                return Ok(false);
356            };
357
358            let piece_state = state
359                .pieces
360                .get_mut(piece_idx)
361                .with_context(|| format!("bug: missing piece state for piece #{}", piece_idx))?;
362
363            let PieceState::Downloading { start, .. } = piece_state else {
364                anyhow::bail!("bug: how did we even get here?");
365            };
366
367            // Update peer download stats
368            let elapsed_secs = start.elapsed().as_secs_f64();
369            let (ref mut peer_piece_download_times_sum, ref mut peer_downloaded_pieces) =
370                state.peer_download_stats.entry(peer_addr).or_insert((0., 0.));
371            *peer_piece_download_times_sum += elapsed_secs;
372            *peer_downloaded_pieces += 1.;
373
374            // Mark piece as downloaded
375            *piece_state = PieceState::Downloaded;
376
377            self.storage_tx
378                .send(StorageOp::CheckPieceHash((
379                    peer_addr.to_owned(),
380                    index,
381                    expected_piece_hash.to_owned(),
382                )))
383                .await
384                .with_context(|| {
385                    format!(
386                        "Failed to send a 'check piece hash' request to the storage backend: index {}",
387                        index
388                    )
389                })?;
390        }
391
392        Ok(true)
393    }
394
395    fn get_peer_req_tx(&self, peer_addr: &SocketAddrV4) -> Option<&mpsc::Sender<TorrentManagerReq>> {
396        self.peer_req_txs.get(peer_addr)
397    }
398
399    fn remove_peer_req_tx(&mut self, peer_addr: &SocketAddrV4) -> Option<mpsc::Sender<TorrentManagerReq>> {
400        self.peer_req_txs.remove(peer_addr)
401    }
402
403    fn get_piece_hash(&self, index: usize) -> Option<&[u8; 20]> {
404        self.piece_hashes.get(index)
405    }
406
407    async fn verify_piece_not_stolen(
408        state: &TorrentSharedState,
409        peer_addr: SocketAddrV4,
410        piece_idx: usize,
411    ) -> anyhow::Result<bool> {
412        let piece_status = state.get_piece_status(piece_idx).context("bug: bad piece index?")?;
413        match piece_status {
414            PieceState::Downloading { peer, .. } if *peer == peer_addr => Ok(true),
415            _ => Ok(false),
416        }
417    }
418
419    async fn disconnect_peer(
420        &mut self,
421        peer_addr: &SocketAddrV4,
422        disconnect_reason: &'static str,
423        send_disconnect: bool,
424    ) -> anyhow::Result<()> {
425        if send_disconnect {
426            if let Some(req_tx) = self.remove_peer_req_tx(&peer_addr) {
427                if req_tx
428                    .send(TorrentManagerReq::Disconnect(disconnect_reason))
429                    .await
430                    .is_err()
431                {
432                    tracing::debug!(
433                        %peer_addr,
434                        "error while shutting down a peer: peer already dropped the receiver"
435                    );
436                }
437            };
438        }
439
440        let mut state = self.shared_state.write().await;
441        let reset_pieces = state
442            .pieces
443            .iter_mut()
444            .enumerate()
445            .filter_map(|(idx, piece_state)| match piece_state {
446                PieceState::Downloading { peer, .. } if peer == peer_addr => {
447                    *piece_state = PieceState::Queued;
448                    Some(idx)
449                }
450                _ => None,
451            })
452            .collect::<Vec<usize>>();
453        for piece in reset_pieces.into_iter() {
454            let piece_idx = try_into!(piece, u32)?;
455            state.piece_download_progress.insert(piece_idx, 0);
456        }
457
458        NUMBER_OF_PEERS.fetch_sub(1, Ordering::Relaxed);
459
460        Ok(())
461    }
462}