Skip to main content

vortex_bittorrent/peer_comm/
extended_protocol.rs

1use std::collections::BTreeMap;
2
3use bitvec::{boxed::BitBox, vec::BitVec};
4use bt_bencode::{ByteString, Deserializer, Value};
5use bytes::{Buf, Bytes};
6use serde::{Deserialize, Serialize};
7use sha1::{Digest, Sha1};
8
9use crate::{
10    peer_comm::peer_connection::PeerConnection,
11    piece_selector::SUBPIECE_SIZE,
12    torrent::{Config, StateRef},
13};
14
15use super::{peer_connection::DisconnectReason, peer_protocol::PeerMessage};
16
/// Name under which the metadata-exchange extension is advertised and
/// matched in the extension handshake.
pub const UT_METADATA: &str = "ut_metadata";
/// Name under which the upload-only extension is advertised and matched
/// in the extension handshake.
pub const UPLOAD_ONLY: &str = "upload_only";
19
20pub fn init_extension<'state>(
21    id: u8,
22    name: &str,
23    handshake_dict: &BTreeMap<ByteString, Value>,
24    state_ref: &mut StateRef<'state>,
25    outgoing_msgs_buffer: &mut Vec<PeerMessage>,
26) -> Result<Option<Box<dyn ExtensionProtocol>>, DisconnectReason> {
27    match name {
28        UT_METADATA => {
29            let Some(metadata_size) = handshake_dict
30                .get("metadata_size".as_bytes())
31                .and_then(|val| val.as_u64())
32            else {
33                return Err(DisconnectReason::ProtocolError("metadata size not valid"));
34            };
35            if let Some(metadata) = state_ref.metadata() {
36                let expected_size = metadata.construct_info().encode().len();
37                if metadata_size != expected_size as u64 {
38                    return Err(DisconnectReason::ProtocolError("metadata size not valid"));
39                }
40            }
41
42            let mut metadata = MetadataExtension::new(id, metadata_size as usize);
43            if !state_ref.is_initialzied() {
44                for i in 0..8.min(metadata.num_pieces()) {
45                    outgoing_msgs_buffer.push(metadata.request(i as i32));
46                }
47            }
48            Ok(Some(Box::new(metadata)))
49        }
50        UPLOAD_ONLY => Ok(Some(Box::new(UploadOnlyExtension::new(id)))),
51        _ => Ok(None),
52    }
53}
54
55// Supported extensions and this clients ID for them
56pub const EXTENSIONS: [(&str, u8); 2] = [(UT_METADATA, 1), (UPLOAD_ONLY, 2)];
57
58/// The handshake message this peer should send to anyone supporting the
59/// extension
60pub fn extension_handshake_msg(state_ref: &mut StateRef, config: &Config) -> PeerMessage {
61    let mut handshake = BTreeMap::new();
62    let extensions: BTreeMap<_, _> = BTreeMap::from(EXTENSIONS);
63    handshake.insert("m", bt_bencode::value::to_value(&extensions).unwrap());
64    handshake.insert(
65        "v",
66        bt_bencode::value::to_value(&format!("Vortex {}", env!("CARGO_PKG_VERSION"))).unwrap(),
67    );
68    if let Some(listener_port) = state_ref.listener_port {
69        handshake.insert("p", bt_bencode::value::to_value(listener_port).unwrap());
70    }
71    let is_complete = state_ref.state().is_some_and(|state| state.is_complete);
72    if let Some(metadata) = state_ref.metadata() {
73        let metadata_size = metadata.construct_info().encode().len();
74        handshake.insert(
75            "metadata_size",
76            bt_bencode::to_value(&metadata_size).unwrap(),
77        );
78        let upload_only = if is_complete { 1 } else { 0 };
79        handshake.insert(
80            UPLOAD_ONLY,
81            bt_bencode::value::to_value(&upload_only).unwrap(),
82        );
83    }
84    handshake.insert(
85        "reqq",
86        bt_bencode::value::to_value(&config.max_reported_outstanding_requests).unwrap(),
87    );
88
89    PeerMessage::Extended {
90        id: 0,
91        data: bt_bencode::to_vec(&handshake).unwrap().into(),
92    }
93}
94
95pub trait ExtensionProtocol {
96    fn handle_message<'state>(
97        &mut self,
98        data: Bytes,
99        state: &mut StateRef<'state>,
100        // Connection that received the message.
101        // Note you can't modify the extensions map
102        // from inside an extension handler
103        connection: &mut PeerConnection,
104    ) -> Result<(), DisconnectReason>;
105
106    fn on_torrent_complete(&mut self, _outgoing_msgs_buffer: &mut Vec<PeerMessage>) {}
107}
108
109/// An extension protocol supported by libtorrent
110/// to indicate that the peer is a pure seeder
111#[derive(Debug, Deserialize, Serialize)]
112pub struct UploadOnlyExtension {
113    // The peers ID for the extension
114    pub id: u8,
115    pub enabled: bool,
116}
117
118impl UploadOnlyExtension {
119    pub fn new(id: u8) -> Self {
120        Self { id, enabled: false }
121    }
122
123    pub fn upload_only(&mut self, upload_only: bool) -> PeerMessage {
124        let enabled = if upload_only { 1 } else { 0 };
125        PeerMessage::Extended {
126            id: self.id,
127            data: vec![enabled].into(),
128        }
129    }
130}
131
132impl ExtensionProtocol for UploadOnlyExtension {
133    fn handle_message<'state>(
134        &mut self,
135        mut data: Bytes,
136        _state: &mut StateRef<'state>,
137        connection: &mut PeerConnection,
138    ) -> Result<(), DisconnectReason> {
139        let enabled = data
140            .try_get_u8()
141            .map_err(|_err| DisconnectReason::InvalidMessage)?;
142        connection.is_upload_only = enabled > 0;
143        Ok(())
144    }
145
146    fn on_torrent_complete(&mut self, outgoing_msgs_buffer: &mut Vec<PeerMessage>) {
147        outgoing_msgs_buffer.push(self.upload_only(true));
148    }
149}
150
151const REQUEST: u8 = 0;
152const DATA: u8 = 1;
153const REJECT: u8 = 2;
154
155#[derive(Debug, Deserialize, Serialize, PartialEq)]
156pub struct MetadataMessage {
157    pub msg_type: u8,
158    pub piece: i32,
159    #[serde(skip_serializing_if = "Option::is_none")]
160    pub total_size: Option<i32>,
161}
162
/// Per-peer progress of the metadata download.
#[derive(Debug, Default, Clone, Copy, PartialEq)]
pub struct MetadataProgress {
    /// How many metadata pieces there are to download in total.
    pub total_piece: usize,
    /// How many of those pieces have been completed.
    pub completed_pieces: usize,
}
171
172// BEP 9
173pub struct MetadataExtension {
174    // The peers ID for the extension
175    id: u8,
176    metadata: Vec<u8>,
177    // These are only kept up to date
178    // if we are activly downloading the metadata.
179    // If the state is completed there is no guarantee
180    // that these are accurate
181    inflight: BitBox,
182    completed: BitBox,
183}
184
185impl MetadataExtension {
186    pub fn new(id: u8, metadata_size: usize) -> Self {
187        let num_pieces = metadata_size.div_ceil(SUBPIECE_SIZE as usize);
188        let inflight: BitBox = BitVec::repeat(false, num_pieces).into();
189        let completed: BitBox = BitVec::repeat(false, num_pieces).into();
190        Self {
191            id,
192            metadata: vec![0; metadata_size],
193            inflight,
194            completed,
195        }
196    }
197
198    pub fn num_pieces(&self) -> usize {
199        self.inflight.len()
200    }
201
202    pub fn request(&mut self, piece: i32) -> PeerMessage {
203        self.inflight.set(piece as usize, true);
204        let req = MetadataMessage {
205            msg_type: REQUEST,
206            piece,
207            total_size: None,
208        };
209        PeerMessage::Extended {
210            id: self.id,
211            data: bt_bencode::to_vec(&req).expect("valid bencode").into(),
212        }
213    }
214
215    pub fn data(&mut self, piece: i32, metadata_piece: &[u8]) -> PeerMessage {
216        let req = MetadataMessage {
217            msg_type: DATA,
218            piece,
219            // Note this is NOT the length of the piece
220            total_size: Some(self.metadata.len() as i32),
221        };
222        let mut data: Vec<u8> = bt_bencode::to_vec(&req).expect("valid bencode");
223        data.extend_from_slice(metadata_piece);
224        PeerMessage::Extended {
225            id: self.id,
226            data: data.into(),
227        }
228    }
229
230    pub fn reject(&mut self, piece: i32) -> PeerMessage {
231        let req = MetadataMessage {
232            msg_type: REJECT,
233            piece,
234            total_size: None,
235        };
236        PeerMessage::Extended {
237            id: self.id,
238            data: bt_bencode::to_vec(&req).expect("valid bencode").into(),
239        }
240    }
241}
242
243impl ExtensionProtocol for MetadataExtension {
244    fn handle_message<'state>(
245        &mut self,
246        data: Bytes,
247        state: &mut StateRef<'state>,
248        connection: &mut PeerConnection,
249    ) -> Result<(), DisconnectReason> {
250        // TODO: Consider reusing this
251        let mut de = Deserializer::from_slice(&data[..]);
252        let message: MetadataMessage = <MetadataMessage>::deserialize(&mut de)
253            .map_err(|_err| DisconnectReason::ProtocolError("Invalid metadata message"))?;
254        log::trace!(
255            "Got metadata extension message of type: {}",
256            message.msg_type
257        );
258        let piece_idx = usize::try_from(message.piece)
259            .map_err(|_err| DisconnectReason::ProtocolError("Invalid metadata piece index"))?;
260        let Some(start_offset) = piece_idx.checked_mul(SUBPIECE_SIZE as usize) else {
261            return Err(DisconnectReason::ProtocolError(
262                "Invalid metadata piece index",
263            ));
264        };
265        match message.msg_type {
266            REQUEST => {
267                // if we do not have all metadata yet then reject the requests
268                if let Some(metadata) = state.metadata() {
269                    let info_bytes = metadata.construct_info().encode();
270                    if start_offset >= info_bytes.len() {
271                        connection
272                            .outgoing_msgs_buffer
273                            .push(self.reject(message.piece));
274                    } else {
275                        let end = (start_offset + SUBPIECE_SIZE as usize).min(info_bytes.len());
276                        let metadata_piece = &info_bytes[start_offset..end];
277                        connection
278                            .outgoing_msgs_buffer
279                            .push(self.data(message.piece, metadata_piece));
280                    }
281                } else {
282                    connection
283                        .outgoing_msgs_buffer
284                        .push(self.reject(message.piece));
285                }
286                de.end().map_err(|_err| {
287                    DisconnectReason::ProtocolError("Metadata request message longer than expected")
288                })?;
289            }
290            DATA => {
291                // We race compiletion of the metadata so we might
292                // receive DATA messages late, in that case we ignore them
293                if state.is_initialzied() {
294                    return Ok(());
295                }
296                let end = (start_offset + SUBPIECE_SIZE as usize).min(self.metadata.len());
297                let actual_data = &data[de.byte_offset()..];
298                if actual_data.len() < end - start_offset {
299                    return Err(DisconnectReason::ProtocolError("Invalid DATA length"));
300                }
301                connection.network_stats.download_throughput += actual_data.len() as u64;
302                self.metadata[start_offset..end].copy_from_slice(actual_data);
303
304                self.completed.set(piece_idx, true);
305                if let Some(index) = self.inflight.first_zero() {
306                    connection
307                        .outgoing_msgs_buffer
308                        .push(self.request(index as i32));
309                } else if self.completed.all() {
310                    let mut hasher = Sha1::new();
311                    hasher.update(&self.metadata);
312                    let hash = hasher.finalize();
313
314                    if hash.as_slice() != state.info_hash() {
315                        log::error!("Got wrong hash for metadata");
316                        return Err(DisconnectReason::ProtocolError("Metadata hash mismatch"));
317                    } else if state.state().is_none() {
318                        let metadata: Value =
319                            bt_bencode::from_slice(&self.metadata).map_err(|_err| {
320                                DisconnectReason::ProtocolError("Metadata not parsable")
321                            })?;
322                        let mut parsable = BTreeMap::new();
323                        parsable.insert("info", metadata);
324                        let torrent = lava_torrent::torrent::v1::Torrent::read_from_bytes(
325                            // should never panic
326                            bt_bencode::to_vec(&parsable).unwrap().as_slice(),
327                        )
328                        .expect("metadata to be parsable");
329                        state.init(torrent).expect("State to be initialized once");
330                    }
331                }
332            }
333            REJECT => {
334                if piece_idx < self.num_pieces() {
335                    if let Some(index) = self.inflight.first_zero() {
336                        connection
337                            .outgoing_msgs_buffer
338                            .push(self.request(index as i32));
339                    }
340                    self.inflight.set(piece_idx, false);
341                    log::warn!("Got reject request");
342                } else {
343                    log::error!("Got invalid reject request");
344                    return Err(DisconnectReason::ProtocolError("Invalid reject request"));
345                }
346                de.end().map_err(|_err| {
347                    DisconnectReason::ProtocolError("Metadata request message longer than expected")
348                })?;
349            }
350            typ => {
351                if piece_idx < self.num_pieces() {
352                    if let Some(index) = self.inflight.first_zero() {
353                        connection
354                            .outgoing_msgs_buffer
355                            .push(self.request(index as i32));
356                    }
357                    self.inflight.set(piece_idx, false);
358                }
359
360                log::error!("Got metadata extension unknown type: {typ}");
361            }
362        }
363        connection.metadata_progress = Some(MetadataProgress {
364            total_piece: self.num_pieces(),
365            completed_pieces: self.completed.count_ones(),
366        });
367        Ok(())
368    }
369}