webtorrent/
client.rs

1use crate::conn_pool::ConnPool;
2use crate::discovery::Discovery;
3use crate::error::{Result, WebTorrentError};
4use crate::nat::NatTraversal;
5use crate::throttling::ThrottleGroup;
6use crate::torrent::Torrent;
7use bytes::Bytes;
8use rand::Rng;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::{mpsc, RwLock};
12use tokio::time::{Duration, Instant};
13
14/// WebTorrent Client Options
15#[derive(Debug, Clone)]
16pub struct WebTorrentOptions {
17    pub peer_id: Option<[u8; 20]>,
18    pub node_id: Option<[u8; 20]>,
19    pub torrent_port: u16,
20    pub dht_port: u16,
21    pub max_conns: usize,
22    pub utp: bool,
23    pub nat_upnp: bool,
24    pub nat_pmp: bool,
25    pub lsd: bool,
26    pub ut_pex: bool,
27    pub seed_outgoing_connections: bool,
28    pub download_limit: Option<u64>, // bytes per second, None = unlimited
29    pub upload_limit: Option<u64>,   // bytes per second, None = unlimited
30    pub blocklist: Option<String>,
31    pub tracker: Option<TrackerConfig>,
32    pub web_seeds: bool,
33}
34
35impl Default for WebTorrentOptions {
36    fn default() -> Self {
37        Self {
38            peer_id: None,
39            node_id: None,
40            torrent_port: 0,
41            dht_port: 0,
42            max_conns: 55,
43            utp: true,
44            nat_upnp: true,
45            nat_pmp: true,
46            lsd: true,
47            ut_pex: true,
48            seed_outgoing_connections: true,
49            download_limit: None,
50            upload_limit: None,
51            blocklist: None,
52            tracker: None,
53            web_seeds: true,
54        }
55    }
56}
57
/// Tracker configuration carried in `WebTorrentOptions::tracker`.
pub struct TrackerConfig {
    /// Tracker announce URLs.
    pub announce: Vec<String>,
    /// Optional callback returning key/value pairs.
    ///
    /// This is not preserved by `Clone`, because a boxed `Fn` trait object cannot be cloned.
    // NOTE(review): presumably extra announce parameters; the `dead_code` allowance
    // suggests the field is not read outside tests — confirm and drop the attribute if so.
    #[cfg_attr(not(test), allow(dead_code))]
    pub get_announce_opts: Option<Box<dyn Fn() -> HashMap<String, String> + Send + Sync>>,
}
63
64impl std::fmt::Debug for TrackerConfig {
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("TrackerConfig")
67            .field("announce", &self.announce)
68            .field("get_announce_opts", &"<function>")
69            .finish()
70    }
71}
72
73impl Clone for TrackerConfig {
74    fn clone(&self) -> Self {
75        Self {
76            announce: self.announce.clone(),
77            get_announce_opts: None, // Can't clone Fn trait objects
78        }
79    }
80}
81
82/// WebTorrent Client
83pub struct WebTorrent {
84    pub(crate) peer_id: [u8; 20],
85    pub(crate) node_id: [u8; 20],
86    pub(crate) options: WebTorrentOptions,
87    pub(crate) torrents: Arc<RwLock<Vec<Arc<Torrent>>>>,
88    pub(crate) conn_pool: Arc<RwLock<Option<Arc<ConnPool>>>>,
89    pub(crate) nat_traversal: Option<Arc<NatTraversal>>,
90    pub(crate) dht: Option<Arc<Discovery>>,
91    pub(crate) destroyed: Arc<RwLock<bool>>,
92    pub(crate) listening: Arc<RwLock<bool>>,
93    pub(crate) ready: Arc<RwLock<bool>>,
94    pub(crate) torrent_port: Arc<RwLock<u16>>,
95    pub(crate) dht_port: Arc<RwLock<u16>>,
96    pub(crate) event_tx: mpsc::UnboundedSender<ClientEvent>,
97    pub(crate) event_rx: Arc<RwLock<mpsc::UnboundedReceiver<ClientEvent>>>,
98    // Speed tracking
99    download_speed_tracker: Arc<SpeedTracker>,
100    upload_speed_tracker: Arc<SpeedTracker>,
101    // Throttling
102    download_throttle: Arc<ThrottleGroup>,
103    upload_throttle: Arc<ThrottleGroup>,
104}
105
106/// Speed tracker using a rolling window
107struct SpeedTracker {
108    bytes: Arc<RwLock<Vec<(Instant, u64)>>>, // (timestamp, bytes)
109    window: Duration,
110}
111
112impl SpeedTracker {
113    fn new(window: Duration) -> Self {
114        Self {
115            bytes: Arc::new(RwLock::new(Vec::new())),
116            window,
117        }
118    }
119
120    async fn add_bytes(&self, amount: u64) {
121        let now = Instant::now();
122        let mut bytes = self.bytes.write().await;
123        bytes.push((now, amount));
124        
125        // Remove entries outside the window
126        let cutoff = now.checked_sub(self.window).unwrap_or(Instant::now());
127        bytes.retain(|(time, _)| *time > cutoff);
128    }
129
130    async fn get_speed(&self) -> u64 {
131        let bytes = self.bytes.read().await;
132        if bytes.is_empty() {
133            return 0;
134        }
135
136        let now = Instant::now();
137        let cutoff = now.checked_sub(self.window).unwrap_or(Instant::now());
138        
139        let total_bytes: u64 = bytes.iter()
140            .filter(|(time, _)| *time > cutoff)
141            .map(|(_, amount)| *amount)
142            .sum();
143
144        let oldest_time = bytes.iter()
145            .filter(|(time, _)| *time > cutoff)
146            .map(|(time, _)| *time)
147            .min();
148
149        let elapsed = if let Some(oldest) = oldest_time {
150            now.duration_since(oldest)
151        } else {
152            Duration::from_secs(1) // Default to 1 second if no data
153        };
154        if elapsed.as_secs_f64() > 0.0 {
155            (total_bytes as f64 / elapsed.as_secs_f64()) as u64
156        } else {
157            0
158        }
159    }
160}
161
162#[derive(Clone)]
163pub enum ClientEvent {
164    Ready,
165    Listening,
166    TorrentAdded(Arc<Torrent>),
167    TorrentRemoved(Arc<Torrent>),
168    Error(String), // Store error as string for Clone
169    Download(u64),
170    Upload(u64),
171}
172
173impl WebTorrent {
174    /// Get the client's peer ID
175    pub fn peer_id(&self) -> [u8; 20] {
176        self.peer_id
177    }
178    
179    /// Create a new WebTorrent client
180    pub async fn new(options: WebTorrentOptions) -> Result<Self> {
181        let peer_id = options.peer_id.unwrap_or_else(|| {
182            let mut id = [0u8; 20];
183            id[0..3].copy_from_slice(b"-WW");
184            // Version string (4 bytes)
185            let version_str = format!("{:04}", env!("CARGO_PKG_VERSION_MAJOR").parse::<u16>().unwrap_or(1) * 100 + 
186                env!("CARGO_PKG_VERSION_MINOR").parse::<u16>().unwrap_or(0));
187            let version_bytes = version_str.as_bytes();
188            // Ensure exactly 4 bytes for version
189            if version_bytes.len() >= 4 {
190                id[3..7].copy_from_slice(&version_bytes[0..4]);
191            } else {
192                // Pad with zeros if shorter
193                id[3..3+version_bytes.len()].copy_from_slice(version_bytes);
194            }
195            id[7] = b'-';
196            // Random bytes
197            let mut rng = rand::thread_rng();
198            rng.fill(&mut id[8..]);
199            id
200        });
201
202        let node_id = options.node_id.unwrap_or_else(|| {
203            let mut id = [0u8; 20];
204            let mut rng = rand::thread_rng();
205            rng.fill(&mut id);
206            id
207        });
208
209        let (event_tx, event_rx) = mpsc::unbounded_channel();
210
211        // Initialize speed trackers (1 second window)
212        let download_speed_tracker = Arc::new(SpeedTracker::new(Duration::from_secs(1)));
213        let upload_speed_tracker = Arc::new(SpeedTracker::new(Duration::from_secs(1)));
214
215        // Initialize throttling
216        let download_throttle = Arc::new(ThrottleGroup::new(
217            options.download_limit.unwrap_or(u64::MAX),
218            options.download_limit.is_some(),
219        ));
220        let upload_throttle = Arc::new(ThrottleGroup::new(
221            options.upload_limit.unwrap_or(u64::MAX),
222            options.upload_limit.is_some(),
223        ));
224
225        let mut client = Self {
226            peer_id,
227            node_id,
228            options: options.clone(),
229            torrents: Arc::new(RwLock::new(Vec::new())),
230            conn_pool: Arc::new(RwLock::new(None)),
231            nat_traversal: None,
232            dht: None,
233            destroyed: Arc::new(RwLock::new(false)),
234            listening: Arc::new(RwLock::new(false)),
235            ready: Arc::new(RwLock::new(false)),
236            torrent_port: Arc::new(RwLock::new(options.torrent_port)),
237            dht_port: Arc::new(RwLock::new(options.dht_port)),
238            event_tx,
239            event_rx: Arc::new(RwLock::new(event_rx)),
240            download_speed_tracker,
241            upload_speed_tracker,
242            download_throttle,
243            upload_throttle,
244        };
245
246        // Initialize NAT traversal if enabled
247        if options.nat_upnp || options.nat_pmp {
248            let nat = Arc::new(NatTraversal::new(options.nat_upnp, options.nat_pmp).await?);
249            client.nat_traversal = Some(nat);
250        }
251
252        // Connection pool will be initialized on first use to avoid circular reference
253        // This is set up when the client starts listening
254
255        // DHT will be initialized via libp2p when discovery starts
256        // This will be handled in the discovery module
257
258        Ok(client)
259    }
260
261    /// Add a torrent to the client
262    pub async fn add(&self, torrent_id: impl Into<TorrentId>) -> Result<Arc<Torrent>> {
263        if *self.destroyed.read().await {
264            return Err(WebTorrentError::ClientDestroyed);
265        }
266
267        // Initialize connection pool if not already initialized
268        // This ensures the client is listening on a port before announcing to trackers
269        {
270            let port = *self.torrent_port.read().await;
271            if port > 0 {
272                let mut conn_pool_guard = self.conn_pool.write().await;
273                if conn_pool_guard.is_none() {
274                    // Create a clone of self wrapped in Arc for the connection pool
275                    // Note: This creates a new Arc, but the connection pool will use it
276                    let client_for_pool = Arc::new(self.clone());
277                    match ConnPool::new(client_for_pool).await {
278                        Ok(pool) => {
279                            *conn_pool_guard = Some(Arc::new(pool));
280                            *self.listening.write().await = true;
281                        }
282                        Err(e) => {
283                            eprintln!("Warning: Failed to initialize connection pool: {}. Tracker announcements may not work.", e);
284                            // Still mark as listening if port is configured
285                            *self.listening.write().await = true;
286                        }
287                    }
288                } else {
289                    *self.listening.write().await = true;
290                }
291            }
292        }
293
294        let torrent_id = torrent_id.into();
295        let torrent = Torrent::new(torrent_id, self.clone()).await?;
296
297        // Check for duplicates
298        let info_hash = torrent.info_hash();
299        let torrents = self.torrents.read().await;
300        for existing in torrents.iter() {
301            if existing.info_hash() == info_hash {
302                return Err(WebTorrentError::DuplicateTorrent(hex::encode(info_hash)));
303            }
304        }
305        drop(torrents);
306
307        let torrent = Arc::new(torrent);
308        
309        // Start discovery for the torrent (will announce to trackers)
310        torrent.start_discovery().await?;
311        
312        self.torrents.write().await.push(torrent.clone());
313
314        self.event_tx.send(ClientEvent::TorrentAdded(torrent.clone()))
315            .map_err(|_| WebTorrentError::Network("Event channel closed".to_string()))?;
316
317        Ok(torrent)
318    }
319
320    /// Remove a torrent from the client
321    pub async fn remove(&self, torrent: Arc<Torrent>) -> Result<()> {
322        if *self.destroyed.read().await {
323            return Err(WebTorrentError::ClientDestroyed);
324        }
325
326        let mut torrents = self.torrents.write().await;
327        if let Some(pos) = torrents.iter().position(|t| Arc::ptr_eq(t, &torrent)) {
328            torrents.remove(pos);
329            torrent.destroy().await?;
330            self.event_tx.send(ClientEvent::TorrentRemoved(torrent))
331                .map_err(|_| WebTorrentError::Network("Event channel closed".to_string()))?;
332        }
333
334        Ok(())
335    }
336
337    /// Get a torrent by info hash
338    pub async fn get(&self, info_hash: &[u8; 20]) -> Option<Arc<Torrent>> {
339        let torrents = self.torrents.read().await;
340        torrents.iter().find(|t| t.info_hash() == *info_hash).cloned()
341    }
342
343    /// Seed a file or data
344    pub async fn seed(
345        &self,
346        name: String,
347        data: Bytes,
348        announce: Option<Vec<String>>,
349    ) -> Result<Arc<Torrent>> {
350        if *self.destroyed.read().await {
351            return Err(WebTorrentError::ClientDestroyed);
352        }
353
354        use crate::torrent_creator::TorrentCreator;
355
356        // Use provided announce or default tracker
357        let announce_list = announce.unwrap_or_else(|| {
358            vec!["http://dig-relay-prod.eba-2cmanxbe.us-east-1.elasticbeanstalk.com:8000/announce".to_string()]
359        });
360
361        // Create torrent
362        let creator = TorrentCreator::new()
363            .with_announce(announce_list.clone());
364        
365        let (torrent_file, info_hash) = creator.create_from_data(name.clone(), data.clone()).await?;
366
367        // Check for duplicates
368        if self.get(&info_hash).await.is_some() {
369            return Err(WebTorrentError::DuplicateTorrent(hex::encode(info_hash)));
370        }
371
372        // Add torrent
373        let torrent = self.add(torrent_file).await?;
374
375        // Store the data in the torrent's store
376        // This will be handled by the torrent when it's ready
377
378        Ok(torrent)
379    }
380
381    /// Get download speed in bytes per second
382    pub async fn download_speed(&self) -> u64 {
383        self.download_speed_tracker.get_speed().await
384    }
385
386    /// Get upload speed in bytes per second
387    pub async fn upload_speed(&self) -> u64 {
388        self.upload_speed_tracker.get_speed().await
389    }
390
391    /// Record downloaded bytes for speed tracking
392    #[cfg_attr(test, allow(dead_code))]
393    pub(crate) async fn record_download(&self, bytes: u64) {
394        if bytes > 0 {
395            self.download_speed_tracker.add_bytes(bytes).await;
396            let _ = self.event_tx.send(ClientEvent::Download(bytes));
397        }
398    }
399
400    /// Record uploaded bytes for speed tracking
401    #[allow(dead_code)]
402    pub(crate) async fn record_upload(&self, bytes: u64) {
403        if bytes > 0 {
404            self.upload_speed_tracker.add_bytes(bytes).await;
405            let _ = self.event_tx.send(ClientEvent::Upload(bytes));
406        }
407    }
408
409    /// Get overall progress (0.0 to 1.0)
410    pub async fn progress(&self) -> f64 {
411        let torrents = self.torrents.read().await;
412        let mut total_downloaded = 0u64;
413        let mut total_length = 0u64;
414
415        for torrent in torrents.iter() {
416            if torrent.progress().await < 1.0 {
417                total_downloaded += torrent.downloaded().await;
418                total_length += torrent.length().await;
419            }
420        }
421
422        if total_length == 0 {
423            return 1.0;
424        }
425
426        total_downloaded as f64 / total_length as f64
427    }
428
429    /// Get overall ratio (uploaded / downloaded)
430    pub async fn ratio(&self) -> f64 {
431        let torrents = self.torrents.read().await;
432        let mut total_uploaded = 0u64;
433        let mut total_received = 0u64;
434
435        for torrent in torrents.iter() {
436            total_uploaded += torrent.uploaded().await;
437            total_received += torrent.received().await;
438        }
439
440        if total_received == 0 {
441            return 0.0;
442        }
443
444        total_uploaded as f64 / total_received as f64
445    }
446
447    /// Set download throttle rate (bytes per second, None = unlimited)
448    pub async fn throttle_download(&self, rate: Option<u64>) {
449        if let Some(rate) = rate {
450            self.download_throttle.set_rate(rate).await;
451            self.download_throttle.set_enabled(true).await;
452        } else {
453            self.download_throttle.set_enabled(false).await;
454        }
455    }
456
457    /// Set upload throttle rate (bytes per second, None = unlimited)
458    pub async fn throttle_upload(&self, rate: Option<u64>) {
459        if let Some(rate) = rate {
460            self.upload_throttle.set_rate(rate).await;
461            self.upload_throttle.set_enabled(true).await;
462        } else {
463            self.upload_throttle.set_enabled(false).await;
464        }
465    }
466
467    /// Get download throttle group (for use by peers/wires)
468    #[allow(dead_code)]
469    pub(crate) fn download_throttle(&self) -> Arc<ThrottleGroup> {
470        Arc::clone(&self.download_throttle)
471    }
472
473    /// Get upload throttle group (for use by peers/wires)
474    #[allow(dead_code)]
475    pub(crate) fn upload_throttle(&self) -> Arc<ThrottleGroup> {
476        Arc::clone(&self.upload_throttle)
477    }
478
479    /// Destroy the client
480    pub async fn destroy(&self) -> Result<()> {
481        if *self.destroyed.read().await {
482            return Err(WebTorrentError::ClientDestroyed);
483        }
484
485        *self.destroyed.write().await = true;
486
487        // Destroy all torrents
488        let torrents = self.torrents.read().await.clone();
489        for torrent in torrents {
490            let _ = torrent.destroy().await;
491        }
492
493        // Destroy connection pool
494        if let Some(conn_pool) = self.conn_pool.read().await.as_ref() {
495            conn_pool.destroy().await?;
496        }
497
498        // Destroy NAT traversal
499        if let Some(nat) = &self.nat_traversal {
500            nat.destroy().await?;
501        }
502
503        // Destroy DHT
504        if let Some(dht) = &self.dht {
505            dht.destroy().await?;
506        }
507
508        Ok(())
509    }
510
511    /// Get listening address
512    pub async fn address(&self) -> Option<(String, u16)> {
513        if !*self.listening.read().await {
514            return None;
515        }
516
517        let port = *self.torrent_port.read().await;
518        Some(("0.0.0.0".to_string(), port))
519    }
520}
521
522impl Clone for WebTorrent {
523    fn clone(&self) -> Self {
524        Self {
525            peer_id: self.peer_id,
526            node_id: self.node_id,
527            options: self.options.clone(),
528            torrents: Arc::clone(&self.torrents),
529            conn_pool: Arc::clone(&self.conn_pool),
530            nat_traversal: self.nat_traversal.clone(),
531            dht: self.dht.clone(),
532            destroyed: Arc::clone(&self.destroyed),
533            listening: Arc::clone(&self.listening),
534            ready: Arc::clone(&self.ready),
535            torrent_port: Arc::clone(&self.torrent_port),
536            dht_port: Arc::clone(&self.dht_port),
537            event_tx: self.event_tx.clone(),
538            event_rx: Arc::clone(&self.event_rx),
539            download_speed_tracker: Arc::clone(&self.download_speed_tracker),
540            upload_speed_tracker: Arc::clone(&self.upload_speed_tracker),
541            download_throttle: Arc::clone(&self.download_throttle),
542            upload_throttle: Arc::clone(&self.upload_throttle),
543        }
544    }
545}
546
547/// Torrent identifier - can be info hash, magnet URI, or torrent file data
548#[derive(Debug, Clone)]
549pub enum TorrentId {
550    InfoHash([u8; 20]),
551    MagnetUri(String),
552    TorrentFile(Bytes),
553    Url(String),
554}
555
556impl From<[u8; 20]> for TorrentId {
557    fn from(hash: [u8; 20]) -> Self {
558        TorrentId::InfoHash(hash)
559    }
560}
561
562impl From<String> for TorrentId {
563    fn from(s: String) -> Self {
564        if s.starts_with("magnet:") {
565            TorrentId::MagnetUri(s)
566        } else if s.starts_with("http://") || s.starts_with("https://") {
567            TorrentId::Url(s)
568        } else {
569            // Assume it's a hex-encoded info hash
570            if let Ok(bytes) = hex::decode(&s) {
571                if bytes.len() == 20 {
572                    let mut hash = [0u8; 20];
573                    hash.copy_from_slice(&bytes);
574                    return TorrentId::InfoHash(hash);
575                }
576            }
577            TorrentId::MagnetUri(s) // Fallback to magnet URI
578        }
579    }
580}
581
582impl From<&str> for TorrentId {
583    fn from(s: &str) -> Self {
584        s.to_string().into()
585    }
586}
587
588impl From<Bytes> for TorrentId {
589    fn from(bytes: Bytes) -> Self {
590        TorrentId::TorrentFile(bytes)
591    }
592}
593
594impl From<Vec<u8>> for TorrentId {
595    fn from(bytes: Vec<u8>) -> Self {
596        TorrentId::TorrentFile(bytes.into())
597    }
598}
599