1use std::collections::{HashMap, VecDeque};
2use std::net::SocketAddrV4;
3use std::sync::atomic::Ordering;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use anyhow::Context;
8use bittorrent_peer_protocol::Block;
9use bitvec::order::Msb0;
10use bitvec::slice::BitSlice;
11use tokio::sync::mpsc::{self, UnboundedReceiver};
12use tokio::sync::RwLock;
13use tokio::time;
14
15use super::event::{PeerEvent, TorrentManagerReq};
16use crate::stats::{DOWNLOADED_BYTES, DOWNLOADED_PIECES, NUMBER_OF_PEERS};
17use crate::storage::StorageOp;
18use crate::torrent_meta::TorrentMeta;
19use crate::util::piece_size_from_idx;
20
/// Lifecycle of a single torrent piece as tracked by the torrent manager.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PieceState {
    /// Nobody is downloading the piece yet; it can be handed out to a peer.
    Queued,
    /// The piece is assigned to `peer`; `start` is the moment it was assigned.
    Downloading {
        peer: SocketAddrV4,
        start: Instant,
    },
    /// All bytes of the piece were received; its hash has not been confirmed yet.
    Downloaded,
    /// The piece hash was confirmed.
    Verified,
}
28
29#[derive(Debug)]
31pub struct TorrentSharedState {
32 peer_download_stats: HashMap<SocketAddrV4, (f64, f64)>,
33 pieces: Vec<PieceState>,
34 piece_download_progress: HashMap<u32, usize>,
35 cancellation_req_queue: VecDeque<(SocketAddrV4, u32)>,
36}
37
38impl TorrentSharedState {
39 pub fn new(piece_states: Vec<PieceState>, number_of_pieces: usize) -> Self {
40 TorrentSharedState {
41 peer_download_stats: HashMap::new(),
42 pieces: piece_states,
43 piece_download_progress: HashMap::with_capacity(number_of_pieces),
44 cancellation_req_queue: VecDeque::new(),
45 }
46 }
47}
48
49impl TorrentSharedState {
50 fn get_piece_steal_coeff(&self) -> f64 {
51 let total_pieces = self.pieces.len() as f64;
52 let downloaded_pieces = DOWNLOADED_PIECES.load(Ordering::Relaxed) as f64;
53 if downloaded_pieces / total_pieces >= 0.8 {
54 3.
55 } else {
56 10.
57 }
58 }
59
60 pub fn get_next_missing_piece_indexes(
62 &mut self,
63 peer_addr: SocketAddrV4,
64 peer_available_pieces: &BitSlice<u8, Msb0>,
65 number_of_pieces: usize,
66 ) -> anyhow::Result<Option<Vec<(u32, Option<SocketAddrV4>)>>> {
67 let piece_steal_coeff = self.get_piece_steal_coeff();
68 let pieces = self
69 .pieces
70 .iter_mut()
71 .enumerate()
72 .filter_map(|(idx, status)| {
73 if !peer_available_pieces[idx] {
74 return None;
75 }
76 match status {
77 PieceState::Queued => {
78 *status = PieceState::Downloading {
79 peer: peer_addr,
80 start: Instant::now(),
81 };
82 Some(try_into!(idx, u32).map(|idx| (idx, None)))
83 }
84 PieceState::Downloading { peer, start } => {
85 let peer = *peer;
86 if peer == peer_addr {
87 return None;
88 }
89
90 let requesting_peer_stats = self.peer_download_stats.entry(peer_addr).or_insert((0., 0.));
91 if requesting_peer_stats.1 == 0. {
92 return None;
93 }
94 let requesting_peer_avg_time = requesting_peer_stats.0 / requesting_peer_stats.1;
95
96 let elapsed_secs = start.elapsed().as_secs_f64();
97 if elapsed_secs > requesting_peer_avg_time * piece_steal_coeff {
99 tracing::debug!(
100 %peer_addr,
101 stolen_from=%peer,
102 piece=%idx,
103 "stole a piece: elapsed time {}, my avg piece time: {}",
104 elapsed_secs,
105 requesting_peer_avg_time
106 );
107
108 let (ref mut cur_peer_piece_download_times_sum, _) =
110 self.peer_download_stats.entry(peer).or_insert((0., 0.));
111 *cur_peer_piece_download_times_sum += elapsed_secs;
112
113 *status = PieceState::Downloading {
114 peer: peer_addr,
115 start: Instant::now(),
116 };
117 Some(try_into!(idx, u32).map(|idx| (idx, Some(peer))))
118 } else {
119 None
120 }
121 }
122 _ => None,
123 }
124 })
125 .take(number_of_pieces)
126 .collect::<anyhow::Result<Vec<(u32, Option<SocketAddrV4>)>>>()
127 .context("bug: converting piece index to u32 failed - too many pieces?")
128 .map(|vec| if vec.is_empty() { None } else { Some(vec) })?;
129
130 if let Some(pieces) = pieces.as_ref() {
131 for (piece_idx, stolen_from) in pieces.iter().filter(|(_, stolen_from)| stolen_from.is_some()) {
132 self.cancellation_req_queue
133 .push_back((stolen_from.unwrap(), *piece_idx));
134
135 self.piece_download_progress.insert(*piece_idx, 0);
137 }
138 }
139
140 Ok(pieces)
141 }
142
143 pub fn get_piece_status(&self, idx: usize) -> Option<&PieceState> {
144 self.pieces.get(idx)
145 }
146
147 pub fn finished_downloading(&self) -> bool {
149 self.pieces.iter().all(|state| state == &PieceState::Verified)
150 }
151
152 pub fn get_number_of_pieces(&self) -> usize {
154 self.pieces.len()
155 }
156
157 pub fn mark_piece_as_verified(&mut self, piece_idx: usize) {
158 if let Some(piece_state) = self.pieces.get_mut(piece_idx) {
159 *piece_state = PieceState::Verified
160 };
161 }
162
163 pub fn on_peer_disconnect(&mut self, dead_peer_addr: &SocketAddrV4) {
164 self.pieces.iter_mut().for_each(|piece| match piece {
165 PieceState::Downloading { peer, .. } if peer == dead_peer_addr => {
166 *piece = PieceState::Queued;
167 }
168 _ => {}
169 })
170 }
171}
172
173#[derive(Debug)]
174pub struct Torrent {
175 torrent_meta: TorrentMeta,
176 shared_state: Arc<RwLock<TorrentSharedState>>,
177 piece_hashes: Vec<[u8; 20]>,
178 peer_req_txs: HashMap<SocketAddrV4, mpsc::Sender<TorrentManagerReq>>,
180 rx: UnboundedReceiver<(SocketAddrV4, PeerEvent)>,
182 new_peer_req_rx: UnboundedReceiver<(SocketAddrV4, mpsc::Sender<TorrentManagerReq>)>,
184 storage_tx: mpsc::Sender<StorageOp>,
186 storage_rx: mpsc::Receiver<(SocketAddrV4, u32, bool)>,
188}
189
190impl Torrent {
191 pub fn new(
192 torrent_meta: TorrentMeta,
193 state: Arc<RwLock<TorrentSharedState>>,
194 piece_hashes: Vec<[u8; 20]>,
195 rx: UnboundedReceiver<(SocketAddrV4, PeerEvent)>,
196 new_peer_req_rx: UnboundedReceiver<(SocketAddrV4, mpsc::Sender<TorrentManagerReq>)>,
197 storage_tx: mpsc::Sender<StorageOp>,
198 storage_rx: mpsc::Receiver<(SocketAddrV4, u32, bool)>,
199 ) -> Self {
200 Torrent {
201 torrent_meta,
202 shared_state: state,
203 peer_req_txs: HashMap::new(),
204 piece_hashes,
205 rx,
206 new_peer_req_rx,
207 storage_tx,
208 storage_rx,
209 }
210 }
211
212 #[tracing::instrument(level = "debug", err, skip(self))]
213 pub async fn handle(&mut self) -> anyhow::Result<()> {
214 let mut interval = time::interval(Duration::from_secs(1));
215 loop {
216 tokio::select! {
217 result = self.rx.recv() => {
218 if let Some((peer_addr, event)) = result {
219 match event {
220 PeerEvent::BlockDownloaded(block) => {
221 let mut state = self.shared_state.write().await;
224 if Torrent::verify_piece_not_stolen(&state, peer_addr, try_into!(block.index, usize)?).await? {
225 if !self.add_block(&mut state, peer_addr, block).await? {
226 drop(state);
227 self.disconnect_peer(&peer_addr, "error while adding a block", true).await?;
228 };
229 }
230 }
231 PeerEvent::Disconnected => {
232 tracing::debug!(%peer_addr, "peer exited unexpectedly");
233 self.remove_peer_req_tx(&peer_addr);
235 self.disconnect_peer(&peer_addr, "peer exited", false).await?;
236
237 if self.peer_req_txs.is_empty() && !self.shared_state.read().await.finished_downloading() {
238 anyhow::bail!("All peers exited before finishing the torrent, the download is incomplete");
239 }
240 }
241 };
242 } else {
243 anyhow::bail!("bug: all peers exited unexepectedly?");
246 }
247 },
248 result = self.storage_rx.recv() => {
249 let Some((peer_addr, piece_idx, is_correct)) = result else {
250 anyhow::bail!("bug: storage backend exited before torrent manager?");
251 };
252
253 if !is_correct {
254 tracing::debug!(
255 %peer_addr,
256 %piece_idx,
257 "piece hash verification failed: disconnecting the peer"
258 );
259
260 self.disconnect_peer(&peer_addr, "piece hash verification failed", true).await?;
261 } else {
262 let mut shared_state = self.shared_state.write().await;
263 let piece_state = shared_state.pieces.get_mut(try_into!(piece_idx, usize)?).context("bug: downloaded a ghost piece?")?;
264
265 *piece_state = PieceState::Verified;
266
267 let downloaded_bytes = shared_state.piece_download_progress.get(&piece_idx).context("bug: downloaded a piece but didn't track its bytes?")?;
268 DOWNLOADED_BYTES.fetch_add(*downloaded_bytes, Ordering::Relaxed);
269 DOWNLOADED_PIECES.fetch_add(1, Ordering::Relaxed);
270
271 if shared_state.finished_downloading() {
272 tracing::info!("Successfully finished downloading the torrent");
273 tracing::debug!("shutting down peers");
274 for (peer_addr, req_tx) in self.peer_req_txs.drain() {
275 if req_tx.send(TorrentManagerReq::Disconnect("finished downloading")).await.is_err() {
276 tracing::debug!(
277 %peer_addr,
278 "error while shutting down a peer: peer already dropped the receiver"
279 );
280 }
281 };
282 break;
283 }
284 }
285 },
286 result = self.new_peer_req_rx.recv() => {
287 if let Some((peer_addr, peer_tx_channel)) = result {
288 NUMBER_OF_PEERS.fetch_add(1, Ordering::Relaxed);
289 self.peer_req_txs.insert(peer_addr, peer_tx_channel);
290 };
291 }
292 _ = interval.tick() => {}
293 }
294 let mut state = self.shared_state.write().await;
295 while let Some((peer_addr, piece_idx)) = state.cancellation_req_queue.pop_front() {
296 if let Some(sender) = self.get_peer_req_tx(&peer_addr) {
297 if let Err(e) = sender.send(TorrentManagerReq::CancelPiece(piece_idx)).await {
298 tracing::debug!(
299 %peer_addr,
300 "error while sending a cancellation request: {}", e
301 );
302 }
303 }
304 }
305 }
306
307 Ok(())
308 }
309
310 #[tracing::instrument(err, skip(self, state, begin, block))]
311 async fn add_block(
312 &self,
313 state: &mut TorrentSharedState,
314 peer_addr: SocketAddrV4,
315 Block { index, begin, block }: Block,
316 ) -> anyhow::Result<bool> {
317 let downloaded_bytes = state.piece_download_progress.entry(index).or_insert(0);
318
319 let expected_piece_size = piece_size_from_idx(
320 self.torrent_meta.number_of_pieces,
321 self.torrent_meta.total_length,
322 self.torrent_meta.piece_size,
323 try_into!(index, usize)?,
324 );
325
326 if *downloaded_bytes + block.len() > expected_piece_size {
327 tracing::debug!(
328 "piece is larger than expected: {} vs {}",
329 *downloaded_bytes + block.len(),
330 expected_piece_size,
331 );
332 return Ok(false);
333 }
334
335 *downloaded_bytes += block.len();
336
337 self.storage_tx
338 .send(StorageOp::AddBlock(Block { index, begin, block }))
339 .await
340 .with_context(|| {
341 format!(
342 "Failed to send a block to the storage backend: index {}, in-piece offset: {}",
343 index, begin
344 )
345 })?;
346
347 if *downloaded_bytes == expected_piece_size {
348 let piece_idx = try_into!(index, usize)?;
349 let Some(expected_piece_hash) = self.get_piece_hash(piece_idx) else {
350 tracing::debug!(
351 "Wrong piece index: index {}, total pieces: {}",
352 piece_idx,
353 self.shared_state.read().await.get_number_of_pieces()
354 );
355 return Ok(false);
356 };
357
358 let piece_state = state
359 .pieces
360 .get_mut(piece_idx)
361 .with_context(|| format!("bug: missing piece state for piece #{}", piece_idx))?;
362
363 let PieceState::Downloading { start, .. } = piece_state else {
364 anyhow::bail!("bug: how did we even get here?");
365 };
366
367 let elapsed_secs = start.elapsed().as_secs_f64();
369 let (ref mut peer_piece_download_times_sum, ref mut peer_downloaded_pieces) =
370 state.peer_download_stats.entry(peer_addr).or_insert((0., 0.));
371 *peer_piece_download_times_sum += elapsed_secs;
372 *peer_downloaded_pieces += 1.;
373
374 *piece_state = PieceState::Downloaded;
376
377 self.storage_tx
378 .send(StorageOp::CheckPieceHash((
379 peer_addr.to_owned(),
380 index,
381 expected_piece_hash.to_owned(),
382 )))
383 .await
384 .with_context(|| {
385 format!(
386 "Failed to send a 'check piece hash' request to the storage backend: index {}",
387 index
388 )
389 })?;
390 }
391
392 Ok(true)
393 }
394
395 fn get_peer_req_tx(&self, peer_addr: &SocketAddrV4) -> Option<&mpsc::Sender<TorrentManagerReq>> {
396 self.peer_req_txs.get(peer_addr)
397 }
398
399 fn remove_peer_req_tx(&mut self, peer_addr: &SocketAddrV4) -> Option<mpsc::Sender<TorrentManagerReq>> {
400 self.peer_req_txs.remove(peer_addr)
401 }
402
403 fn get_piece_hash(&self, index: usize) -> Option<&[u8; 20]> {
404 self.piece_hashes.get(index)
405 }
406
407 async fn verify_piece_not_stolen(
408 state: &TorrentSharedState,
409 peer_addr: SocketAddrV4,
410 piece_idx: usize,
411 ) -> anyhow::Result<bool> {
412 let piece_status = state.get_piece_status(piece_idx).context("bug: bad piece index?")?;
413 match piece_status {
414 PieceState::Downloading { peer, .. } if *peer == peer_addr => Ok(true),
415 _ => Ok(false),
416 }
417 }
418
419 async fn disconnect_peer(
420 &mut self,
421 peer_addr: &SocketAddrV4,
422 disconnect_reason: &'static str,
423 send_disconnect: bool,
424 ) -> anyhow::Result<()> {
425 if send_disconnect {
426 if let Some(req_tx) = self.remove_peer_req_tx(&peer_addr) {
427 if req_tx
428 .send(TorrentManagerReq::Disconnect(disconnect_reason))
429 .await
430 .is_err()
431 {
432 tracing::debug!(
433 %peer_addr,
434 "error while shutting down a peer: peer already dropped the receiver"
435 );
436 }
437 };
438 }
439
440 let mut state = self.shared_state.write().await;
441 let reset_pieces = state
442 .pieces
443 .iter_mut()
444 .enumerate()
445 .filter_map(|(idx, piece_state)| match piece_state {
446 PieceState::Downloading { peer, .. } if peer == peer_addr => {
447 *piece_state = PieceState::Queued;
448 Some(idx)
449 }
450 _ => None,
451 })
452 .collect::<Vec<usize>>();
453 for piece in reset_pieces.into_iter() {
454 let piece_idx = try_into!(piece, u32)?;
455 state.piece_download_progress.insert(piece_idx, 0);
456 }
457
458 NUMBER_OF_PEERS.fetch_sub(1, Ordering::Relaxed);
459
460 Ok(())
461 }
462}