rustybit_lib/
peer.rs

1use std::collections::VecDeque;
2use std::net::SocketAddrV4;
3use std::sync::Arc;
4use std::time::Duration;
5
6use anyhow::Context;
7use bittorrent_peer_protocol::{BittorrentP2pMessage, Block, BlockRequest, Encode, Handshake};
8use bitvec::order::Msb0;
9use bitvec::vec::BitVec;
10use tokio::io::AsyncWriteExt;
11use tokio::net::TcpStream;
12use tokio::sync::mpsc::{self};
13use tokio::sync::{self, RwLock};
14use tokio::time;
15
16use crate::buffer::ReadBuf;
17use crate::state::event::{PeerEvent, TorrentManagerReq};
18use crate::state::torrent::PieceState;
19use crate::torrent_meta::TorrentMeta;
20use crate::util::piece_size_from_idx;
21use crate::{Elapsed, TorrentSharedState, WithTimeout, DEFAULT_BLOCK_SIZE};
22
23#[tracing::instrument(level = "error", err(level = tracing::Level::DEBUG), skip_all, fields(%peer_addr))]
24pub async fn handle_peer(
25    peer_addr: SocketAddrV4,
26    metadata: TorrentMeta,
27    client_peer_id: [u8; 20],
28    state: Arc<RwLock<TorrentSharedState>>,
29    tx: mpsc::UnboundedSender<(SocketAddrV4, PeerEvent)>,
30    new_peer_req_tx: mpsc::UnboundedSender<(SocketAddrV4, mpsc::Sender<TorrentManagerReq>)>,
31) -> anyhow::Result<()> {
32    let mut stream = TcpStream::connect(peer_addr)
33        .with_timeout("peer connect", Duration::from_secs(5))
34        .await
35        .context("establishing connection with a peer")?;
36
37    tracing::debug!("connected to a peer");
38
39    // Send the handshake message
40    let handshake_message = Handshake::new(metadata.info_hash, client_peer_id);
41    handshake_message.encode(&mut stream).await?;
42
43    let mut read_buf = ReadBuf::new();
44    let handshake = read_buf
45        .read_handshake(&mut stream)
46        .with_timeout("read_handshake", Duration::from_secs(5))
47        .await
48        .context("reading peer handshake")?;
49
50    if handshake.pstr != Handshake::DEFAULT_PSTR {
51        tracing::debug!("Peer sent a bad PSTR, disconnecting: {}", handshake.pstr);
52        anyhow::bail!("bad handshake pstr")
53    }
54
55    tracing::trace!("read a handshake");
56
57    let mut output = Vec::with_capacity(metadata.piece_size * 2);
58    let mut handler = PeerHandler::new(state, metadata, peer_addr);
59
60    let (manager_req_tx, mut manager_req_rx) = sync::mpsc::channel(10);
61    new_peer_req_tx
62        .send((peer_addr, manager_req_tx))
63        .context("error while registering a new peer with the manager")?;
64
65    tracing::trace!("handshakes done, starting a peer handling task");
66
67    let clonex_tx = tx.clone();
68    let task_handle = tokio::spawn(async move {
69        let mut piece_request_interval = time::interval(Duration::from_secs(1));
70        let mut keep_alive_interval = time::interval(Duration::from_secs(120));
71        // Skip the first tick, as it completes immediately and we just opened a new connection
72        keep_alive_interval.tick().await;
73
74        loop {
75            tokio::select! {
76                message = read_buf.read_message(&mut stream).with_elapsed("read_message", Some(Duration::from_millis(100))) => {
77                    let message = message.context("reading message")?;
78                    if let Some(event) = handler
79                        .handle_message(message, &mut output)
80                        .with_elapsed("handle_message", Some(Duration::from_millis(50)))
81                        .await?
82                    {
83                        clonex_tx.send((peer_addr, event)).context("sending a peer event")?;
84                    };
85                }
86                req = manager_req_rx.recv() => {
87                    match req.context("bug: manager dropped the event-sending half?")? {
88                        TorrentManagerReq::CancelPiece(piece_idx) => {
89                            handler.cancel_block_requests_for_piece(piece_idx, &mut output).await.context("cancelling block requests for stolen piece")?;
90                        }
91                        TorrentManagerReq::Disconnect(reason) => {
92                            tracing::trace!(%reason, "cancellation requested, peer exiting");
93                            break;
94                        }
95                    }
96                }
97                // Try to steal a piece
98                _ = piece_request_interval.tick() => {}
99                _ = keep_alive_interval.tick() => {
100                    // It's time to send a Keep Alive message
101                    handler.send_keep_alive(&mut output).await?;
102                }
103            }
104
105            if !handler.peer_choked
106                && handler.present_pieces.is_some()
107                && handler.block_requests_queue.len() < PeerHandler::MAX_PENDING_BLOCK_REQUESTS
108            {
109                let need_blocks = PeerHandler::MAX_PENDING_BLOCK_REQUESTS - handler.block_requests_queue.len();
110                let number_of_pieces = need_blocks / handler.get_blocks_per_piece() + 1;
111                let next_pieces = handler
112                    .state
113                    .write()
114                    .await
115                    // SAFETY: checked above
116                    .get_next_missing_piece_indexes(
117                        handler.peer_addr,
118                        handler.present_pieces.as_ref().unwrap(),
119                        number_of_pieces,
120                    )
121                    .context("bug: getting next pieces failed?")?;
122                if let Some(piece_indexes) = next_pieces {
123                    // We found new pieces that can be downloaded from this peer
124                    for (index, ..) in piece_indexes.into_iter() {
125                        handler
126                            .add_block_requests_for_piece(index)
127                            .with_context(|| format!("adding block requests for piece: {}", index))?;
128                    }
129                }
130            }
131
132            // Send block requests if we have them and are not choked
133            handler.send_block_requests(&mut output).await?;
134
135            stream.write_all(&output).await.context("writing to the stream")?;
136            output.clear();
137        }
138
139        Ok::<(), anyhow::Error>(())
140    });
141
142    match task_handle.await.context("peer handling task panicked")? {
143        Ok(()) => Ok(()),
144        Err(e) => {
145            tx.send((peer_addr, PeerEvent::Disconnected))
146                .with_context(|| format!("peer graceful shutdown failed. Error that caused shutdown: {:#}", e))?;
147            Err(e)
148        }
149    }
150}
151
152struct PeerHandler {
153    peer_addr: SocketAddrV4,
154    state: Arc<RwLock<TorrentSharedState>>,
155    /// Contains all torrent-related information that a peer handler may need
156    torrent_metadata: TorrentMeta,
157    /// Bitvec with all pieces that a peer has
158    present_pieces: Option<BitVec<u8, Msb0>>,
159    /// Block requests that will be sent when we receive a response to previous ones
160    block_requests_queue: VecDeque<BlockRequest>,
161    /// Block requests that we sent and expect a response
162    sent_block_requests: Vec<BlockRequest>,
163    /// Whether the client is interested in the remote peer
164    client_interested: bool,
165    /// Whether the client chokes the remote peer
166    client_choked: bool,
167    /// Whether the remote peer is interested in the client
168    peer_interested: bool,
169    /// Whether the remote peer choked the client
170    peer_choked: bool,
171}
172
173impl PeerHandler {
174    const MAX_PENDING_BLOCK_REQUESTS: usize = 70;
175
176    pub fn new(state: Arc<RwLock<TorrentSharedState>>, torrent_metadata: TorrentMeta, peer_addr: SocketAddrV4) -> Self {
177        PeerHandler {
178            peer_addr,
179            state,
180            block_requests_queue: VecDeque::with_capacity(Self::MAX_PENDING_BLOCK_REQUESTS),
181            sent_block_requests: Vec::with_capacity(Self::MAX_PENDING_BLOCK_REQUESTS),
182            torrent_metadata,
183            client_choked: true,
184            peer_choked: true,
185            present_pieces: None,
186            client_interested: false,
187            peer_interested: false,
188        }
189    }
190
191    #[tracing::instrument(level = "debug", err, skip_all)]
192    async fn handle_message(
193        &mut self,
194        message: BittorrentP2pMessage,
195        output: &mut Vec<u8>,
196    ) -> anyhow::Result<Option<PeerEvent>> {
197        use BittorrentP2pMessage::*;
198        tracing::trace!(message_id = ?message.message_id(), "handling a message");
199
200        match message {
201            Choke => self.peer_choked = true,
202            Unchoke => self.peer_choked = false,
203            Interested => self.peer_interested = true,
204            NotInterested => self.peer_interested = false,
205            Have(piece_idx) => {
206                let bitvec = self
207                    .present_pieces
208                    .as_mut()
209                    .ok_or_else(|| anyhow::anyhow!("bug: have message received before bitvec"))?;
210
211                bitvec.set(piece_idx as usize, true);
212            }
213            Bitfield(mut bitvec) => {
214                // Remove spare bits
215                let state = self.state.read().await;
216                let n_of_pieces = state.get_number_of_pieces();
217                bitvec.truncate(n_of_pieces);
218                self.present_pieces = Some(bitvec);
219
220                drop(state);
221
222                // We can start asking for pieces now
223                self.send_chocked(false, output).await?;
224                self.send_interested(true, output).await?;
225            }
226            Request(BlockRequest { index, begin, length }) => {
227                tracing::trace!(index, begin, length, "received a Request message from peer");
228            }
229            Piece(Block { index, begin, block }) => {
230                if self.sent_block_requests.is_empty() {
231                    // A block was most likely cancelled before
232                    return Ok(None);
233                }
234
235                let Some(block_request_idx) = self
236                    .sent_block_requests
237                    .iter()
238                    .position(|req| req.index == index && req.begin == begin)
239                else {
240                    // This piece was stolen earlier
241                    self.cancel_block_requests_for_piece(index, output).await?;
242                    return Ok(None);
243                };
244                self.sent_block_requests.remove(block_request_idx);
245
246                let state = self.state.read().await;
247                let piece = state
248                    .get_piece_status(try_into!(index, usize)?)
249                    .context("bug: unexsisting piece index?")?;
250                match piece {
251                    PieceState::Downloading { peer, .. } if *peer == self.peer_addr => {
252                        return Ok(Some(PeerEvent::BlockDownloaded(Block { index, begin, block })));
253                    }
254                    PieceState::Downloading { .. } | PieceState::Downloaded { .. } | PieceState::Verified => {
255                        // Someone stole the piece, ignoring the received block
256                        drop(state);
257                        self.cancel_block_requests_for_piece(index, output).await?;
258                        return Ok(None);
259                    }
260                    _ => anyhow::bail!("bug: someone put a piece that we were downloading back in the queue?"),
261                }
262            }
263            Cancel { index, begin, length } => {
264                tracing::trace!(index, begin, length, "received a Cancel message from peer");
265            }
266            Port(port) => {
267                tracing::trace!(?port, "received a Port message");
268            }
269            KeepAlive => {
270                tracing::debug!("received a KeepAlive message");
271            }
272        };
273
274        Ok(None)
275    }
276
277    async fn send_block_requests(&mut self, output: &mut Vec<u8>) -> anyhow::Result<()> {
278        for _ in 0..(Self::MAX_PENDING_BLOCK_REQUESTS - self.sent_block_requests.len()) {
279            if let Some(request) = self.block_requests_queue.pop_front() {
280                BittorrentP2pMessage::Request(request.clone()).encode(output).await?;
281                self.sent_block_requests.push(request);
282            } else {
283                break;
284            }
285        }
286
287        Ok(())
288    }
289
290    async fn cancel_block_requests_for_piece(&mut self, piece_idx: u32, output: &mut Vec<u8>) -> anyhow::Result<()> {
291        // Remove queued requests for this piece if any
292        self.block_requests_queue.retain(|block| block.index != piece_idx);
293
294        let requests_to_cancel = self
295            .sent_block_requests
296            .iter()
297            .enumerate()
298            .filter(|(_, block)| block.index == piece_idx)
299            .map(|(idx, _)| idx)
300            .collect::<Vec<usize>>();
301
302        for (offset, index) in requests_to_cancel.iter().enumerate() {
303            let BlockRequest { index, begin, length } = self.sent_block_requests.swap_remove(index - offset);
304            BittorrentP2pMessage::Cancel { index, begin, length }
305                .encode(output)
306                .await?;
307        }
308
309        Ok(())
310    }
311
312    async fn send_interested(&mut self, client_interested: bool, output: &mut Vec<u8>) -> anyhow::Result<()> {
313        if client_interested {
314            BittorrentP2pMessage::Interested.encode(output).await?;
315        } else {
316            BittorrentP2pMessage::NotInterested.encode(output).await?;
317        }
318
319        self.client_interested = client_interested;
320
321        Ok(())
322    }
323
324    async fn send_chocked(&mut self, client_choked: bool, output: &mut Vec<u8>) -> anyhow::Result<()> {
325        if client_choked {
326            BittorrentP2pMessage::Choke.encode(output).await?;
327        } else {
328            BittorrentP2pMessage::Unchoke.encode(output).await?;
329        }
330
331        self.client_choked = client_choked;
332
333        Ok(())
334    }
335
336    async fn send_keep_alive(&self, output: &mut Vec<u8>) -> anyhow::Result<()> {
337        BittorrentP2pMessage::KeepAlive.encode(output).await?;
338
339        Ok(())
340    }
341
342    fn add_block_requests_for_piece(&mut self, index: u32) -> anyhow::Result<()> {
343        let mut leftover_piece_size = piece_size_from_idx(
344            self.torrent_metadata.number_of_pieces,
345            self.torrent_metadata.total_length,
346            self.torrent_metadata.piece_size,
347            try_into!(index, usize)?,
348        ) as u32;
349        let mut begin = 0u32;
350        while leftover_piece_size > 0 {
351            let length = DEFAULT_BLOCK_SIZE.min(leftover_piece_size);
352            let request = BlockRequest { index, begin, length };
353            self.block_requests_queue.push_back(request);
354            leftover_piece_size -= length;
355            begin += length;
356        }
357
358        Ok(())
359    }
360
361    fn get_blocks_per_piece(&self) -> usize {
362        let piece_block_size = DEFAULT_BLOCK_SIZE as usize;
363        // Round upwards
364        self.torrent_metadata.piece_size / piece_block_size
365            + (self.torrent_metadata.piece_size % piece_block_size != 0) as usize
366    }
367}