vortex_bittorrent/peer_comm/
extended_protocol.rs1use std::collections::BTreeMap;
2
3use bitvec::{boxed::BitBox, vec::BitVec};
4use bt_bencode::{ByteString, Deserializer, Value};
5use bytes::{Buf, Bytes};
6use serde::{Deserialize, Serialize};
7use sha1::{Digest, Sha1};
8
9use crate::{
10 peer_comm::peer_connection::PeerConnection,
11 piece_selector::SUBPIECE_SIZE,
12 torrent::{Config, StateRef},
13};
14
15use super::{peer_connection::DisconnectReason, peer_protocol::PeerMessage};
16
17pub const UT_METADATA: &str = "ut_metadata";
18pub const UPLOAD_ONLY: &str = "upload_only";
19
20pub fn init_extension<'state>(
21 id: u8,
22 name: &str,
23 handshake_dict: &BTreeMap<ByteString, Value>,
24 state_ref: &mut StateRef<'state>,
25 outgoing_msgs_buffer: &mut Vec<PeerMessage>,
26) -> Result<Option<Box<dyn ExtensionProtocol>>, DisconnectReason> {
27 match name {
28 UT_METADATA => {
29 let Some(metadata_size) = handshake_dict
30 .get("metadata_size".as_bytes())
31 .and_then(|val| val.as_u64())
32 else {
33 return Err(DisconnectReason::ProtocolError("metadata size not valid"));
34 };
35 if let Some(metadata) = state_ref.metadata() {
36 let expected_size = metadata.construct_info().encode().len();
37 if metadata_size != expected_size as u64 {
38 return Err(DisconnectReason::ProtocolError("metadata size not valid"));
39 }
40 }
41
42 let mut metadata = MetadataExtension::new(id, metadata_size as usize);
43 if !state_ref.is_initialzied() {
44 for i in 0..8.min(metadata.num_pieces()) {
45 outgoing_msgs_buffer.push(metadata.request(i as i32));
46 }
47 }
48 Ok(Some(Box::new(metadata)))
49 }
50 UPLOAD_ONLY => Ok(Some(Box::new(UploadOnlyExtension::new(id)))),
51 _ => Ok(None),
52 }
53}
54
55pub const EXTENSIONS: [(&str, u8); 2] = [(UT_METADATA, 1), (UPLOAD_ONLY, 2)];
57
58pub fn extension_handshake_msg(state_ref: &mut StateRef, config: &Config) -> PeerMessage {
61 let mut handshake = BTreeMap::new();
62 let extensions: BTreeMap<_, _> = BTreeMap::from(EXTENSIONS);
63 handshake.insert("m", bt_bencode::value::to_value(&extensions).unwrap());
64 handshake.insert(
65 "v",
66 bt_bencode::value::to_value(&format!("Vortex {}", env!("CARGO_PKG_VERSION"))).unwrap(),
67 );
68 if let Some(listener_port) = state_ref.listener_port {
69 handshake.insert("p", bt_bencode::value::to_value(listener_port).unwrap());
70 }
71 let is_complete = state_ref.state().is_some_and(|state| state.is_complete);
72 if let Some(metadata) = state_ref.metadata() {
73 let metadata_size = metadata.construct_info().encode().len();
74 handshake.insert(
75 "metadata_size",
76 bt_bencode::to_value(&metadata_size).unwrap(),
77 );
78 let upload_only = if is_complete { 1 } else { 0 };
79 handshake.insert(
80 UPLOAD_ONLY,
81 bt_bencode::value::to_value(&upload_only).unwrap(),
82 );
83 }
84 handshake.insert(
85 "reqq",
86 bt_bencode::value::to_value(&config.max_reported_outstanding_requests).unwrap(),
87 );
88
89 PeerMessage::Extended {
90 id: 0,
91 data: bt_bencode::to_vec(&handshake).unwrap().into(),
92 }
93}
94
95pub trait ExtensionProtocol {
96 fn handle_message<'state>(
97 &mut self,
98 data: Bytes,
99 state: &mut StateRef<'state>,
100 connection: &mut PeerConnection,
104 ) -> Result<(), DisconnectReason>;
105
106 fn on_torrent_complete(&mut self, _outgoing_msgs_buffer: &mut Vec<PeerMessage>) {}
107}
108
109#[derive(Debug, Deserialize, Serialize)]
112pub struct UploadOnlyExtension {
113 pub id: u8,
115 pub enabled: bool,
116}
117
118impl UploadOnlyExtension {
119 pub fn new(id: u8) -> Self {
120 Self { id, enabled: false }
121 }
122
123 pub fn upload_only(&mut self, upload_only: bool) -> PeerMessage {
124 let enabled = if upload_only { 1 } else { 0 };
125 PeerMessage::Extended {
126 id: self.id,
127 data: vec![enabled].into(),
128 }
129 }
130}
131
132impl ExtensionProtocol for UploadOnlyExtension {
133 fn handle_message<'state>(
134 &mut self,
135 mut data: Bytes,
136 _state: &mut StateRef<'state>,
137 connection: &mut PeerConnection,
138 ) -> Result<(), DisconnectReason> {
139 let enabled = data
140 .try_get_u8()
141 .map_err(|_err| DisconnectReason::InvalidMessage)?;
142 connection.is_upload_only = enabled > 0;
143 Ok(())
144 }
145
146 fn on_torrent_complete(&mut self, outgoing_msgs_buffer: &mut Vec<PeerMessage>) {
147 outgoing_msgs_buffer.push(self.upload_only(true));
148 }
149}
150
/// BEP 9 `msg_type`: ask the peer for one metadata piece.
const REQUEST: u8 = 0x00;
/// BEP 9 `msg_type`: a metadata piece follows the bencoded header.
const DATA: u8 = 0x01;
/// BEP 9 `msg_type`: the peer will not send the requested piece.
const REJECT: u8 = 0x02;
154
155#[derive(Debug, Deserialize, Serialize, PartialEq)]
156pub struct MetadataMessage {
157 pub msg_type: u8,
158 pub piece: i32,
159 #[serde(skip_serializing_if = "Option::is_none")]
160 pub total_size: Option<i32>,
161}
162
/// Snapshot of how far the `ut_metadata` download from one peer has progressed.
#[derive(Default, PartialEq, Clone, Copy, Debug)]
pub struct MetadataProgress {
    /// Total number of metadata pieces.
    pub total_piece: usize,
    /// Number of metadata pieces received so far.
    pub completed_pieces: usize,
}
171
172pub struct MetadataExtension {
174 id: u8,
176 metadata: Vec<u8>,
177 inflight: BitBox,
182 completed: BitBox,
183}
184
185impl MetadataExtension {
186 pub fn new(id: u8, metadata_size: usize) -> Self {
187 let num_pieces = metadata_size.div_ceil(SUBPIECE_SIZE as usize);
188 let inflight: BitBox = BitVec::repeat(false, num_pieces).into();
189 let completed: BitBox = BitVec::repeat(false, num_pieces).into();
190 Self {
191 id,
192 metadata: vec![0; metadata_size],
193 inflight,
194 completed,
195 }
196 }
197
198 pub fn num_pieces(&self) -> usize {
199 self.inflight.len()
200 }
201
202 pub fn request(&mut self, piece: i32) -> PeerMessage {
203 self.inflight.set(piece as usize, true);
204 let req = MetadataMessage {
205 msg_type: REQUEST,
206 piece,
207 total_size: None,
208 };
209 PeerMessage::Extended {
210 id: self.id,
211 data: bt_bencode::to_vec(&req).expect("valid bencode").into(),
212 }
213 }
214
215 pub fn data(&mut self, piece: i32, metadata_piece: &[u8]) -> PeerMessage {
216 let req = MetadataMessage {
217 msg_type: DATA,
218 piece,
219 total_size: Some(self.metadata.len() as i32),
221 };
222 let mut data: Vec<u8> = bt_bencode::to_vec(&req).expect("valid bencode");
223 data.extend_from_slice(metadata_piece);
224 PeerMessage::Extended {
225 id: self.id,
226 data: data.into(),
227 }
228 }
229
230 pub fn reject(&mut self, piece: i32) -> PeerMessage {
231 let req = MetadataMessage {
232 msg_type: REJECT,
233 piece,
234 total_size: None,
235 };
236 PeerMessage::Extended {
237 id: self.id,
238 data: bt_bencode::to_vec(&req).expect("valid bencode").into(),
239 }
240 }
241}
242
243impl ExtensionProtocol for MetadataExtension {
244 fn handle_message<'state>(
245 &mut self,
246 data: Bytes,
247 state: &mut StateRef<'state>,
248 connection: &mut PeerConnection,
249 ) -> Result<(), DisconnectReason> {
250 let mut de = Deserializer::from_slice(&data[..]);
252 let message: MetadataMessage = <MetadataMessage>::deserialize(&mut de)
253 .map_err(|_err| DisconnectReason::ProtocolError("Invalid metadata message"))?;
254 log::trace!(
255 "Got metadata extension message of type: {}",
256 message.msg_type
257 );
258 let piece_idx = usize::try_from(message.piece)
259 .map_err(|_err| DisconnectReason::ProtocolError("Invalid metadata piece index"))?;
260 let Some(start_offset) = piece_idx.checked_mul(SUBPIECE_SIZE as usize) else {
261 return Err(DisconnectReason::ProtocolError(
262 "Invalid metadata piece index",
263 ));
264 };
265 match message.msg_type {
266 REQUEST => {
267 if let Some(metadata) = state.metadata() {
269 let info_bytes = metadata.construct_info().encode();
270 if start_offset >= info_bytes.len() {
271 connection
272 .outgoing_msgs_buffer
273 .push(self.reject(message.piece));
274 } else {
275 let end = (start_offset + SUBPIECE_SIZE as usize).min(info_bytes.len());
276 let metadata_piece = &info_bytes[start_offset..end];
277 connection
278 .outgoing_msgs_buffer
279 .push(self.data(message.piece, metadata_piece));
280 }
281 } else {
282 connection
283 .outgoing_msgs_buffer
284 .push(self.reject(message.piece));
285 }
286 de.end().map_err(|_err| {
287 DisconnectReason::ProtocolError("Metadata request message longer than expected")
288 })?;
289 }
290 DATA => {
291 if state.is_initialzied() {
294 return Ok(());
295 }
296 let end = (start_offset + SUBPIECE_SIZE as usize).min(self.metadata.len());
297 let actual_data = &data[de.byte_offset()..];
298 if actual_data.len() < end - start_offset {
299 return Err(DisconnectReason::ProtocolError("Invalid DATA length"));
300 }
301 connection.network_stats.download_throughput += actual_data.len() as u64;
302 self.metadata[start_offset..end].copy_from_slice(actual_data);
303
304 self.completed.set(piece_idx, true);
305 if let Some(index) = self.inflight.first_zero() {
306 connection
307 .outgoing_msgs_buffer
308 .push(self.request(index as i32));
309 } else if self.completed.all() {
310 let mut hasher = Sha1::new();
311 hasher.update(&self.metadata);
312 let hash = hasher.finalize();
313
314 if hash.as_slice() != state.info_hash() {
315 log::error!("Got wrong hash for metadata");
316 return Err(DisconnectReason::ProtocolError("Metadata hash mismatch"));
317 } else if state.state().is_none() {
318 let metadata: Value =
319 bt_bencode::from_slice(&self.metadata).map_err(|_err| {
320 DisconnectReason::ProtocolError("Metadata not parsable")
321 })?;
322 let mut parsable = BTreeMap::new();
323 parsable.insert("info", metadata);
324 let torrent = lava_torrent::torrent::v1::Torrent::read_from_bytes(
325 bt_bencode::to_vec(&parsable).unwrap().as_slice(),
327 )
328 .expect("metadata to be parsable");
329 state.init(torrent).expect("State to be initialized once");
330 }
331 }
332 }
333 REJECT => {
334 if piece_idx < self.num_pieces() {
335 if let Some(index) = self.inflight.first_zero() {
336 connection
337 .outgoing_msgs_buffer
338 .push(self.request(index as i32));
339 }
340 self.inflight.set(piece_idx, false);
341 log::warn!("Got reject request");
342 } else {
343 log::error!("Got invalid reject request");
344 return Err(DisconnectReason::ProtocolError("Invalid reject request"));
345 }
346 de.end().map_err(|_err| {
347 DisconnectReason::ProtocolError("Metadata request message longer than expected")
348 })?;
349 }
350 typ => {
351 if piece_idx < self.num_pieces() {
352 if let Some(index) = self.inflight.first_zero() {
353 connection
354 .outgoing_msgs_buffer
355 .push(self.request(index as i32));
356 }
357 self.inflight.set(piece_idx, false);
358 }
359
360 log::error!("Got metadata extension unknown type: {typ}");
361 }
362 }
363 connection.metadata_progress = Some(MetadataProgress {
364 total_piece: self.num_pieces(),
365 completed_pieces: self.completed.count_ones(),
366 });
367 Ok(())
368 }
369}