1use std::collections::VecDeque;
2use std::net::SocketAddrV4;
3use std::sync::Arc;
4use std::time::Duration;
5
6use anyhow::Context;
7use bittorrent_peer_protocol::{BittorrentP2pMessage, Block, BlockRequest, Encode, Handshake};
8use bitvec::order::Msb0;
9use bitvec::vec::BitVec;
10use tokio::io::AsyncWriteExt;
11use tokio::net::TcpStream;
12use tokio::sync::mpsc::{self};
13use tokio::sync::{self, RwLock};
14use tokio::time;
15
16use crate::buffer::ReadBuf;
17use crate::state::event::{PeerEvent, TorrentManagerReq};
18use crate::state::torrent::PieceState;
19use crate::torrent_meta::TorrentMeta;
20use crate::util::piece_size_from_idx;
21use crate::{Elapsed, TorrentSharedState, WithTimeout, DEFAULT_BLOCK_SIZE};
22
23#[tracing::instrument(level = "error", err(level = tracing::Level::DEBUG), skip_all, fields(%peer_addr))]
24pub async fn handle_peer(
25 peer_addr: SocketAddrV4,
26 metadata: TorrentMeta,
27 client_peer_id: [u8; 20],
28 state: Arc<RwLock<TorrentSharedState>>,
29 tx: mpsc::UnboundedSender<(SocketAddrV4, PeerEvent)>,
30 new_peer_req_tx: mpsc::UnboundedSender<(SocketAddrV4, mpsc::Sender<TorrentManagerReq>)>,
31) -> anyhow::Result<()> {
32 let mut stream = TcpStream::connect(peer_addr)
33 .with_timeout("peer connect", Duration::from_secs(5))
34 .await
35 .context("establishing connection with a peer")?;
36
37 tracing::debug!("connected to a peer");
38
39 let handshake_message = Handshake::new(metadata.info_hash, client_peer_id);
41 handshake_message.encode(&mut stream).await?;
42
43 let mut read_buf = ReadBuf::new();
44 let handshake = read_buf
45 .read_handshake(&mut stream)
46 .with_timeout("read_handshake", Duration::from_secs(5))
47 .await
48 .context("reading peer handshake")?;
49
50 if handshake.pstr != Handshake::DEFAULT_PSTR {
51 tracing::debug!("Peer sent a bad PSTR, disconnecting: {}", handshake.pstr);
52 anyhow::bail!("bad handshake pstr")
53 }
54
55 tracing::trace!("read a handshake");
56
57 let mut output = Vec::with_capacity(metadata.piece_size * 2);
58 let mut handler = PeerHandler::new(state, metadata, peer_addr);
59
60 let (manager_req_tx, mut manager_req_rx) = sync::mpsc::channel(10);
61 new_peer_req_tx
62 .send((peer_addr, manager_req_tx))
63 .context("error while registering a new peer with the manager")?;
64
65 tracing::trace!("handshakes done, starting a peer handling task");
66
67 let clonex_tx = tx.clone();
68 let task_handle = tokio::spawn(async move {
69 let mut piece_request_interval = time::interval(Duration::from_secs(1));
70 let mut keep_alive_interval = time::interval(Duration::from_secs(120));
71 keep_alive_interval.tick().await;
73
74 loop {
75 tokio::select! {
76 message = read_buf.read_message(&mut stream).with_elapsed("read_message", Some(Duration::from_millis(100))) => {
77 let message = message.context("reading message")?;
78 if let Some(event) = handler
79 .handle_message(message, &mut output)
80 .with_elapsed("handle_message", Some(Duration::from_millis(50)))
81 .await?
82 {
83 clonex_tx.send((peer_addr, event)).context("sending a peer event")?;
84 };
85 }
86 req = manager_req_rx.recv() => {
87 match req.context("bug: manager dropped the event-sending half?")? {
88 TorrentManagerReq::CancelPiece(piece_idx) => {
89 handler.cancel_block_requests_for_piece(piece_idx, &mut output).await.context("cancelling block requests for stolen piece")?;
90 }
91 TorrentManagerReq::Disconnect(reason) => {
92 tracing::trace!(%reason, "cancellation requested, peer exiting");
93 break;
94 }
95 }
96 }
97 _ = piece_request_interval.tick() => {}
99 _ = keep_alive_interval.tick() => {
100 handler.send_keep_alive(&mut output).await?;
102 }
103 }
104
105 if !handler.peer_choked
106 && handler.present_pieces.is_some()
107 && handler.block_requests_queue.len() < PeerHandler::MAX_PENDING_BLOCK_REQUESTS
108 {
109 let need_blocks = PeerHandler::MAX_PENDING_BLOCK_REQUESTS - handler.block_requests_queue.len();
110 let number_of_pieces = need_blocks / handler.get_blocks_per_piece() + 1;
111 let next_pieces = handler
112 .state
113 .write()
114 .await
115 .get_next_missing_piece_indexes(
117 handler.peer_addr,
118 handler.present_pieces.as_ref().unwrap(),
119 number_of_pieces,
120 )
121 .context("bug: getting next pieces failed?")?;
122 if let Some(piece_indexes) = next_pieces {
123 for (index, ..) in piece_indexes.into_iter() {
125 handler
126 .add_block_requests_for_piece(index)
127 .with_context(|| format!("adding block requests for piece: {}", index))?;
128 }
129 }
130 }
131
132 handler.send_block_requests(&mut output).await?;
134
135 stream.write_all(&output).await.context("writing to the stream")?;
136 output.clear();
137 }
138
139 Ok::<(), anyhow::Error>(())
140 });
141
142 match task_handle.await.context("peer handling task panicked")? {
143 Ok(()) => Ok(()),
144 Err(e) => {
145 tx.send((peer_addr, PeerEvent::Disconnected))
146 .with_context(|| format!("peer graceful shutdown failed. Error that caused shutdown: {:#}", e))?;
147 Err(e)
148 }
149 }
150}
151
152struct PeerHandler {
153 peer_addr: SocketAddrV4,
154 state: Arc<RwLock<TorrentSharedState>>,
155 torrent_metadata: TorrentMeta,
157 present_pieces: Option<BitVec<u8, Msb0>>,
159 block_requests_queue: VecDeque<BlockRequest>,
161 sent_block_requests: Vec<BlockRequest>,
163 client_interested: bool,
165 client_choked: bool,
167 peer_interested: bool,
169 peer_choked: bool,
171}
172
173impl PeerHandler {
174 const MAX_PENDING_BLOCK_REQUESTS: usize = 70;
175
176 pub fn new(state: Arc<RwLock<TorrentSharedState>>, torrent_metadata: TorrentMeta, peer_addr: SocketAddrV4) -> Self {
177 PeerHandler {
178 peer_addr,
179 state,
180 block_requests_queue: VecDeque::with_capacity(Self::MAX_PENDING_BLOCK_REQUESTS),
181 sent_block_requests: Vec::with_capacity(Self::MAX_PENDING_BLOCK_REQUESTS),
182 torrent_metadata,
183 client_choked: true,
184 peer_choked: true,
185 present_pieces: None,
186 client_interested: false,
187 peer_interested: false,
188 }
189 }
190
191 #[tracing::instrument(level = "debug", err, skip_all)]
192 async fn handle_message(
193 &mut self,
194 message: BittorrentP2pMessage,
195 output: &mut Vec<u8>,
196 ) -> anyhow::Result<Option<PeerEvent>> {
197 use BittorrentP2pMessage::*;
198 tracing::trace!(message_id = ?message.message_id(), "handling a message");
199
200 match message {
201 Choke => self.peer_choked = true,
202 Unchoke => self.peer_choked = false,
203 Interested => self.peer_interested = true,
204 NotInterested => self.peer_interested = false,
205 Have(piece_idx) => {
206 let bitvec = self
207 .present_pieces
208 .as_mut()
209 .ok_or_else(|| anyhow::anyhow!("bug: have message received before bitvec"))?;
210
211 bitvec.set(piece_idx as usize, true);
212 }
213 Bitfield(mut bitvec) => {
214 let state = self.state.read().await;
216 let n_of_pieces = state.get_number_of_pieces();
217 bitvec.truncate(n_of_pieces);
218 self.present_pieces = Some(bitvec);
219
220 drop(state);
221
222 self.send_chocked(false, output).await?;
224 self.send_interested(true, output).await?;
225 }
226 Request(BlockRequest { index, begin, length }) => {
227 tracing::trace!(index, begin, length, "received a Request message from peer");
228 }
229 Piece(Block { index, begin, block }) => {
230 if self.sent_block_requests.is_empty() {
231 return Ok(None);
233 }
234
235 let Some(block_request_idx) = self
236 .sent_block_requests
237 .iter()
238 .position(|req| req.index == index && req.begin == begin)
239 else {
240 self.cancel_block_requests_for_piece(index, output).await?;
242 return Ok(None);
243 };
244 self.sent_block_requests.remove(block_request_idx);
245
246 let state = self.state.read().await;
247 let piece = state
248 .get_piece_status(try_into!(index, usize)?)
249 .context("bug: unexsisting piece index?")?;
250 match piece {
251 PieceState::Downloading { peer, .. } if *peer == self.peer_addr => {
252 return Ok(Some(PeerEvent::BlockDownloaded(Block { index, begin, block })));
253 }
254 PieceState::Downloading { .. } | PieceState::Downloaded { .. } | PieceState::Verified => {
255 drop(state);
257 self.cancel_block_requests_for_piece(index, output).await?;
258 return Ok(None);
259 }
260 _ => anyhow::bail!("bug: someone put a piece that we were downloading back in the queue?"),
261 }
262 }
263 Cancel { index, begin, length } => {
264 tracing::trace!(index, begin, length, "received a Cancel message from peer");
265 }
266 Port(port) => {
267 tracing::trace!(?port, "received a Port message");
268 }
269 KeepAlive => {
270 tracing::debug!("received a KeepAlive message");
271 }
272 };
273
274 Ok(None)
275 }
276
277 async fn send_block_requests(&mut self, output: &mut Vec<u8>) -> anyhow::Result<()> {
278 for _ in 0..(Self::MAX_PENDING_BLOCK_REQUESTS - self.sent_block_requests.len()) {
279 if let Some(request) = self.block_requests_queue.pop_front() {
280 BittorrentP2pMessage::Request(request.clone()).encode(output).await?;
281 self.sent_block_requests.push(request);
282 } else {
283 break;
284 }
285 }
286
287 Ok(())
288 }
289
290 async fn cancel_block_requests_for_piece(&mut self, piece_idx: u32, output: &mut Vec<u8>) -> anyhow::Result<()> {
291 self.block_requests_queue.retain(|block| block.index != piece_idx);
293
294 let requests_to_cancel = self
295 .sent_block_requests
296 .iter()
297 .enumerate()
298 .filter(|(_, block)| block.index == piece_idx)
299 .map(|(idx, _)| idx)
300 .collect::<Vec<usize>>();
301
302 for (offset, index) in requests_to_cancel.iter().enumerate() {
303 let BlockRequest { index, begin, length } = self.sent_block_requests.swap_remove(index - offset);
304 BittorrentP2pMessage::Cancel { index, begin, length }
305 .encode(output)
306 .await?;
307 }
308
309 Ok(())
310 }
311
312 async fn send_interested(&mut self, client_interested: bool, output: &mut Vec<u8>) -> anyhow::Result<()> {
313 if client_interested {
314 BittorrentP2pMessage::Interested.encode(output).await?;
315 } else {
316 BittorrentP2pMessage::NotInterested.encode(output).await?;
317 }
318
319 self.client_interested = client_interested;
320
321 Ok(())
322 }
323
324 async fn send_chocked(&mut self, client_choked: bool, output: &mut Vec<u8>) -> anyhow::Result<()> {
325 if client_choked {
326 BittorrentP2pMessage::Choke.encode(output).await?;
327 } else {
328 BittorrentP2pMessage::Unchoke.encode(output).await?;
329 }
330
331 self.client_choked = client_choked;
332
333 Ok(())
334 }
335
336 async fn send_keep_alive(&self, output: &mut Vec<u8>) -> anyhow::Result<()> {
337 BittorrentP2pMessage::KeepAlive.encode(output).await?;
338
339 Ok(())
340 }
341
342 fn add_block_requests_for_piece(&mut self, index: u32) -> anyhow::Result<()> {
343 let mut leftover_piece_size = piece_size_from_idx(
344 self.torrent_metadata.number_of_pieces,
345 self.torrent_metadata.total_length,
346 self.torrent_metadata.piece_size,
347 try_into!(index, usize)?,
348 ) as u32;
349 let mut begin = 0u32;
350 while leftover_piece_size > 0 {
351 let length = DEFAULT_BLOCK_SIZE.min(leftover_piece_size);
352 let request = BlockRequest { index, begin, length };
353 self.block_requests_queue.push_back(request);
354 leftover_piece_size -= length;
355 begin += length;
356 }
357
358 Ok(())
359 }
360
361 fn get_blocks_per_piece(&self) -> usize {
362 let piece_block_size = DEFAULT_BLOCK_SIZE as usize;
363 self.torrent_metadata.piece_size / piece_block_size
365 + (self.torrent_metadata.piece_size % piece_block_size != 0) as usize
366 }
367}