// wow_mpq/io/async_reader.rs

//! Optional async I/O support with resource protection
//!
//! Provides non-blocking archive operations while maintaining security boundaries
//! and preventing resource exhaustion in async contexts.
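//!
//! # Example
//!
//! A minimal sketch of the intended flow; the archive path is illustrative,
//! and `tokio::fs::File` satisfies the `AsyncRead + AsyncSeek` bounds:
//!
//! ```ignore
//! let file = tokio::fs::File::open("archive.mpq").await?;
//! let session = std::sync::Arc::new(SessionTracker::new());
//! let reader = AsyncArchiveReader::new(file, session);
//!
//! let mut buffer = vec![0u8; 512];
//! let bytes_read = reader.read_at(0, &mut buffer).await?;
//! ```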

#[cfg(feature = "async")]
use crate::security::{SecurityLimits, SessionTracker};
#[cfg(feature = "async")]
use crate::{Error, Result};
#[cfg(feature = "async")]
use std::sync::Arc;
#[cfg(feature = "async")]
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
#[cfg(feature = "async")]
use tokio::io::{AsyncRead, AsyncReadExt, AsyncSeek, AsyncSeekExt};
#[cfg(feature = "async")]
use tokio::sync::{Mutex, Semaphore};
#[cfg(feature = "async")]
use tokio::time::{Duration, Instant, timeout};

/// Configuration for async I/O operations with security limits
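///
/// # Example
///
/// A minimal sketch of overriding selected defaults; the values shown are
/// illustrative, not recommendations:
///
/// ```ignore
/// let config = AsyncConfig {
///     max_concurrent_ops: 4,
///     operation_timeout: Duration::from_secs(5),
///     ..AsyncConfig::default()
/// };
/// ```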
#[cfg(feature = "async")]
#[derive(Debug, Clone)]
pub struct AsyncConfig {
    /// Maximum concurrent operations per session
    pub max_concurrent_ops: usize,
    /// Timeout for individual I/O operations
    pub operation_timeout: Duration,
    /// Maximum memory usage for async buffers
    pub max_async_memory: usize,
    /// Enable detailed async metrics collection
    pub collect_metrics: bool,
    /// Maximum number of files that can be extracted concurrently
    pub max_concurrent_extractions: usize,
    /// Buffer size for async operations
    pub buffer_size: usize,
}

#[cfg(feature = "async")]
impl Default for AsyncConfig {
    fn default() -> Self {
        Self {
            max_concurrent_ops: 10,
            operation_timeout: Duration::from_secs(30),
            max_async_memory: 64 * 1024 * 1024, // 64 MB
            collect_metrics: false,
            max_concurrent_extractions: 5,
            buffer_size: 64 * 1024, // 64 KB
        }
    }
}

/// Metrics for async operations
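///
/// # Example
///
/// A sketch of reading a snapshot of the counters; `metrics` is an
/// `AsyncMetrics` value:
///
/// ```ignore
/// let stats = metrics.get_stats();
/// println!(
///     "{}/{} operations completed",
///     stats.completed_operations, stats.total_operations
/// );
/// ```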
#[cfg(feature = "async")]
#[derive(Debug, Default)]
pub struct AsyncMetrics {
    /// Total number of operations started
    pub total_operations: AtomicU64,
    /// Number of completed operations
    pub completed_operations: AtomicU64,
    /// Number of cancelled operations
    pub cancelled_operations: AtomicU64,
    /// Number of timed-out operations
    pub timeout_operations: AtomicU64,
    /// Peak memory usage observed during async operations
    pub peak_memory_usage: AtomicU64,
    /// Number of currently active operations
    pub active_operations: AtomicU64,
    /// Total bytes read asynchronously
    pub total_bytes_read: AtomicU64,
}

#[cfg(feature = "async")]
impl AsyncMetrics {
    /// Create new async metrics
    pub fn new() -> Self {
        Self::default()
    }

    /// Record the start of an operation
    pub fn record_operation_start(&self) {
        self.total_operations.fetch_add(1, Ordering::Relaxed);
        self.active_operations.fetch_add(1, Ordering::Relaxed);
    }

    /// Record operation completion
    pub fn record_operation_complete(&self, bytes_read: u64) {
        self.completed_operations.fetch_add(1, Ordering::Relaxed);
        self.active_operations.fetch_sub(1, Ordering::Relaxed);
        self.total_bytes_read
            .fetch_add(bytes_read, Ordering::Relaxed);
    }

    /// Record operation cancellation
    pub fn record_operation_cancelled(&self) {
        self.cancelled_operations.fetch_add(1, Ordering::Relaxed);
        self.active_operations.fetch_sub(1, Ordering::Relaxed);
    }

    /// Record operation timeout
    pub fn record_operation_timeout(&self) {
        self.timeout_operations.fetch_add(1, Ordering::Relaxed);
        self.active_operations.fetch_sub(1, Ordering::Relaxed);
    }

    /// Update peak memory usage
    pub fn update_peak_memory(&self, current_usage: u64) {
        // `fetch_max` performs the compare-and-swap loop internally: the
        // stored peak only ever moves upward, even under concurrent updates.
        self.peak_memory_usage
            .fetch_max(current_usage, Ordering::Relaxed);
    }

    /// Get a snapshot of the current statistics
    pub fn get_stats(&self) -> AsyncOperationStats {
        AsyncOperationStats {
            total_operations: self.total_operations.load(Ordering::Relaxed),
            completed_operations: self.completed_operations.load(Ordering::Relaxed),
            cancelled_operations: self.cancelled_operations.load(Ordering::Relaxed),
            timeout_operations: self.timeout_operations.load(Ordering::Relaxed),
            peak_memory_usage: self.peak_memory_usage.load(Ordering::Relaxed),
            active_operations: self.active_operations.load(Ordering::Relaxed),
            total_bytes_read: self.total_bytes_read.load(Ordering::Relaxed),
        }
    }
}

/// Statistics snapshot for async operations
#[cfg(feature = "async")]
#[derive(Debug, Clone)]
pub struct AsyncOperationStats {
    /// Total number of operations started
    pub total_operations: u64,
    /// Number of completed operations
    pub completed_operations: u64,
    /// Number of cancelled operations
    pub cancelled_operations: u64,
    /// Number of timed-out operations
    pub timeout_operations: u64,
    /// Peak memory usage during operations
    pub peak_memory_usage: u64,
    /// Number of currently active operations
    pub active_operations: u64,
    /// Total bytes read across all operations
    pub total_bytes_read: u64,
}

/// Async-aware MPQ archive reader with resource protection
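///
/// # Example
///
/// A sketch of constructing a reader with a custom configuration; the
/// archive path is illustrative:
///
/// ```ignore
/// let file = tokio::fs::File::open("patch.mpq").await?;
/// let session = std::sync::Arc::new(SessionTracker::new());
/// let config = AsyncConfig {
///     max_concurrent_extractions: 2,
///     ..AsyncConfig::default()
/// };
/// let reader = AsyncArchiveReader::with_config(file, config, session);
/// ```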
#[cfg(feature = "async")]
#[derive(Debug)]
pub struct AsyncArchiveReader<R> {
    reader: Arc<Mutex<R>>,
    config: AsyncConfig,
    session_tracker: Arc<SessionTracker>,
    active_operations: Arc<Semaphore>,
    extraction_semaphore: Arc<Semaphore>,
    metrics: Arc<AsyncMetrics>,
    security_limits: SecurityLimits,
}

#[cfg(feature = "async")]
impl<R: AsyncRead + AsyncSeek + Unpin + Send + 'static> AsyncArchiveReader<R> {
    /// Create a new async archive reader with default configuration
    pub fn new(reader: R, session_tracker: Arc<SessionTracker>) -> Self {
        Self::with_config(reader, AsyncConfig::default(), session_tracker)
    }

    /// Create a new async archive reader with custom configuration
    pub fn with_config(
        reader: R,
        config: AsyncConfig,
        session_tracker: Arc<SessionTracker>,
    ) -> Self {
        let active_operations = Arc::new(Semaphore::new(config.max_concurrent_ops));
        let extraction_semaphore = Arc::new(Semaphore::new(config.max_concurrent_extractions));
        // The counters are cheap, so metrics are always allocated;
        // `collect_metrics` only gates more detailed collection.
        let metrics = Arc::new(AsyncMetrics::new());

        Self {
            reader: Arc::new(Mutex::new(reader)),
            config,
            session_tracker,
            active_operations,
            extraction_semaphore,
            metrics,
            security_limits: SecurityLimits::default(),
        }
    }

    /// Create with custom security limits
    pub fn with_security_limits(
        reader: R,
        config: AsyncConfig,
        session_tracker: Arc<SessionTracker>,
        security_limits: SecurityLimits,
    ) -> Self {
        let mut async_reader = Self::with_config(reader, config, session_tracker);
        async_reader.security_limits = security_limits;
        async_reader
    }

    /// Read data at a specific offset with timeout and resource protection
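    ///
    /// # Example
    ///
    /// A minimal sketch; the offset and buffer size are illustrative:
    ///
    /// ```ignore
    /// let mut buffer = vec![0u8; 512];
    /// let bytes_read = reader.read_at(1024, &mut buffer).await?;
    /// ```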
    pub async fn read_at(&self, offset: u64, buffer: &mut [u8]) -> Result<usize> {
        // Acquire operation permit
        let _permit = self.active_operations.acquire().await.map_err(|_| {
            // No matching `record_operation_start()` has run yet, so bump only
            // the cancelled counter; decrementing `active_operations` here
            // would underflow it.
            self.metrics
                .cancelled_operations
                .fetch_add(1, Ordering::Relaxed);
            Error::resource_exhaustion("Failed to acquire operation permit - system overloaded")
        })?;

        self.metrics.record_operation_start();

        // Apply timeout to the entire operation
        let result = timeout(self.config.operation_timeout, async {
            // Validate buffer size against security limits
            if buffer.len() > self.config.max_async_memory {
                return Err(Error::resource_exhaustion(
                    "Read buffer exceeds maximum allowed size for async operations",
                ));
            }

            // Update memory usage tracking
            self.metrics.update_peak_memory(buffer.len() as u64);

            // Perform the actual read
            let mut reader = self.reader.lock().await;
            reader.seek(std::io::SeekFrom::Start(offset)).await?;
            let bytes_read = reader.read(buffer).await?;

            Ok(bytes_read)
        })
        .await;

        match result {
            Ok(Ok(bytes_read)) => {
                self.metrics.record_operation_complete(bytes_read as u64);
                Ok(bytes_read)
            }
            Ok(Err(e)) => {
                self.metrics.record_operation_cancelled();
                Err(e)
            }
            Err(_) => {
                self.metrics.record_operation_timeout();
                Err(Error::resource_exhaustion(
                    "Async read operation timed out - potential DoS protection activated",
                ))
            }
        }
    }

    /// Read an exact number of bytes at a specific offset with security validation
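    ///
    /// # Example
    ///
    /// A sketch of filling a fixed-size header buffer; the 32-byte size is
    /// illustrative:
    ///
    /// ```ignore
    /// let mut header = [0u8; 32];
    /// reader.read_exact_at(0, &mut header).await?;
    /// ```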
    pub async fn read_exact_at(&self, offset: u64, buffer: &mut [u8]) -> Result<()> {
        let mut total_read = 0;
        let mut current_offset = offset;

        while total_read < buffer.len() {
            let bytes_read = self
                .read_at(current_offset, &mut buffer[total_read..])
                .await?;

            if bytes_read == 0 {
                return Err(Error::invalid_format(
                    "Unexpected end of file during async read operation",
                ));
            }

            total_read += bytes_read;
            current_offset += bytes_read as u64;
        }

        Ok(())
    }

    /// Perform multiple file extractions concurrently with bounded parallelism
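    ///
    /// # Example
    ///
    /// A sketch; the names, offsets, and sizes are illustrative and would
    /// normally come from the archive's hash and block tables:
    ///
    /// ```ignore
    /// let requests = vec![
    ///     ("(listfile)".to_string(), 0, 256),
    ///     ("war3map.j".to_string(), 256, 1024),
    /// ];
    /// let files = reader.extract_files_concurrent(requests).await?;
    /// ```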
    pub async fn extract_files_concurrent(
        &self,
        file_requests: Vec<(String, u64, u64)>, // (filename, offset, size)
    ) -> Result<Vec<(String, Vec<u8>)>> {
        if file_requests.len() > self.config.max_concurrent_extractions * 2 {
            return Err(Error::resource_exhaustion(
                "Too many concurrent file extraction requests - potential resource exhaustion",
            ));
        }

        // Check session limits for all files combined
        let total_bytes: u64 = file_requests.iter().map(|(_, _, size)| *size).sum();
        self.session_tracker
            .check_session_limits_with_addition(total_bytes, &self.security_limits)?;

        // Validate every requested size up front so a rejected request cannot
        // leave already-spawned tasks running detached.
        for (filename, _, size) in &file_requests {
            if *size > self.security_limits.max_decompressed_size {
                return Err(Error::resource_exhaustion(format!(
                    "File '{}' exceeds maximum size limit for async extraction",
                    filename
                )));
            }
        }

        let mut handles = Vec::new();

        for (filename, offset, size) in file_requests {
            let reader = Arc::clone(&self.reader);
            let extraction_permit = Arc::clone(&self.extraction_semaphore);
            let metrics = Arc::clone(&self.metrics);
            let config = self.config.clone();

            let handle = tokio::spawn(async move {
                let _permit = extraction_permit.acquire().await.map_err(|_| {
                    Error::resource_exhaustion("Failed to acquire extraction permit")
                })?;

                metrics.record_operation_start();

                let result = timeout(config.operation_timeout, async {
                    let mut buffer = vec![0u8; size as usize];
                    let mut reader = reader.lock().await;
                    reader.seek(std::io::SeekFrom::Start(offset)).await?;
                    reader.read_exact(&mut buffer).await?;
                    Ok((filename.clone(), buffer))
                })
                .await;

                match result {
                    Ok(Ok((filename, data))) => {
                        metrics.record_operation_complete(size);
                        Ok((filename, data))
                    }
                    Ok(Err(e)) => {
                        metrics.record_operation_cancelled();
                        Err(e)
                    }
                    Err(_) => {
                        metrics.record_operation_timeout();
                        Err(Error::resource_exhaustion(format!(
                            "Extraction of '{}' timed out - potential DoS protection activated",
                            filename
                        )))
                    }
                }
            });

            handles.push(handle);
        }

        // Collect all results
        let mut results = Vec::new();
        for handle in handles {
            let result = handle
                .await
                .map_err(|e| Error::resource_exhaustion(format!("Async task failed: {}", e)))??;
            results.push(result);
        }

        // Update session tracker with total bytes extracted
        self.session_tracker.record_decompression(total_bytes);

        Ok(results)
    }

    /// Create a decompression monitor for async operations
    pub fn create_decompression_monitor(
        &self,
        expected_size: u64,
        compression_method: u8,
        file_path: Option<&str>,
    ) -> Result<Arc<AsyncDecompressionMonitor>> {
        // Validate decompression request
        crate::security::validate_decompression_operation(
            0, // Compressed size not known at this point
            expected_size,
            compression_method,
            file_path,
            &self.session_tracker,
            &self.security_limits,
        )?;

        Ok(Arc::new(AsyncDecompressionMonitor::new(
            expected_size.min(self.security_limits.max_decompressed_size),
            self.security_limits.max_decompression_time,
            self.config.buffer_size,
        )))
    }

    /// Get current async operation statistics
    pub fn get_stats(&self) -> AsyncOperationStats {
        self.metrics.get_stats()
    }

    /// Check if the reader is under resource pressure
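    ///
    /// # Example
    ///
    /// A sketch of a caller backing off while the reader is saturated; the
    /// sleep interval is illustrative:
    ///
    /// ```ignore
    /// while reader.is_under_pressure() {
    ///     tokio::time::sleep(Duration::from_millis(10)).await;
    /// }
    /// ```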
    pub fn is_under_pressure(&self) -> bool {
        let available_ops = self.active_operations.available_permits();
        let available_extractions = self.extraction_semaphore.available_permits();

        // Consider the reader under pressure if fewer than 20% of permits are
        // available. Use max(1, x / 5) so that pressure is still detectable
        // when the configured limits are small.
        let ops_threshold = std::cmp::max(1, self.config.max_concurrent_ops / 5);
        let extraction_threshold = std::cmp::max(1, self.config.max_concurrent_extractions / 5);

        available_ops < ops_threshold || available_extractions < extraction_threshold
    }

    /// Force cancellation of all pending operations (for cleanup)
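    ///
    /// # Example
    ///
    /// A sketch of a clean teardown:
    ///
    /// ```ignore
    /// reader.shutdown().await?;
    /// // Subsequent reads fail because the semaphores are closed.
    /// ```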
    pub async fn shutdown(&self) -> Result<()> {
        // Close semaphores to prevent new operations
        self.active_operations.close();
        self.extraction_semaphore.close();

        // Give existing operations a moment to complete
        tokio::time::sleep(Duration::from_millis(100)).await;

        Ok(())
    }
}

/// Async-aware decompression monitor with progress tracking
#[cfg(feature = "async")]
#[derive(Debug)]
pub struct AsyncDecompressionMonitor {
    max_size: u64,
    max_time: Duration,
    buffer_size: usize,
    start_time: Instant,
    bytes_decompressed: AtomicU64,
    should_cancel: AtomicBool,
}

#[cfg(feature = "async")]
impl AsyncDecompressionMonitor {
    /// Create a new async decompression monitor
    pub fn new(max_size: u64, max_time: Duration, buffer_size: usize) -> Self {
        Self {
            max_size,
            max_time,
            buffer_size,
            start_time: Instant::now(),
            bytes_decompressed: AtomicU64::new(0),
            should_cancel: AtomicBool::new(false),
        }
    }

    /// Check if decompression should continue (async-safe)
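    ///
    /// # Example
    ///
    /// A sketch of a decompression loop reporting progress after each chunk;
    /// `decompress_chunk` and `output` are hypothetical:
    ///
    /// ```ignore
    /// loop {
    ///     let done = decompress_chunk(&mut output)?;
    ///     monitor.check_progress(output.len() as u64).await?;
    ///     if done {
    ///         break;
    ///     }
    /// }
    /// ```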
    pub async fn check_progress(&self, current_output_size: u64) -> Result<()> {
        // Check size limits
        if current_output_size > self.max_size {
            return Err(Error::resource_exhaustion(
                "Async decompression size limit exceeded - potential compression bomb",
            ));
        }

        // Check time limits
        if self.start_time.elapsed() > self.max_time {
            return Err(Error::resource_exhaustion(
                "Async decompression time limit exceeded - potential DoS attack",
            ));
        }

        // Check if cancellation was requested
        if self.should_cancel.load(Ordering::Relaxed) {
            return Err(Error::resource_exhaustion(
                "Async decompression cancelled due to security limits",
            ));
        }

        // Update current progress
        self.bytes_decompressed
            .store(current_output_size, Ordering::Relaxed);

        // Yield control to allow other tasks to run
        tokio::task::yield_now().await;

        Ok(())
    }

    /// Request cancellation of decompression
    pub fn request_cancellation(&self) {
        self.should_cancel.store(true, Ordering::Relaxed);
    }

    /// Get recommended buffer size for async operations
    pub fn get_buffer_size(&self) -> usize {
        self.buffer_size
    }

    /// Get current statistics: bytes decompressed so far and elapsed time
    pub fn get_stats(&self) -> (u64, Duration) {
        (
            self.bytes_decompressed.load(Ordering::Relaxed),
            self.start_time.elapsed(),
        )
    }
}

#[cfg(test)]
#[cfg(feature = "async")]
mod tests {
    use super::*;
    use std::io::Cursor;
    use std::sync::Arc;

    #[tokio::test]
    async fn test_async_config_default() {
        let config = AsyncConfig::default();
        assert_eq!(config.max_concurrent_ops, 10);
        assert_eq!(config.operation_timeout, Duration::from_secs(30));
        assert_eq!(config.max_async_memory, 64 * 1024 * 1024);
        assert!(!config.collect_metrics);
        assert_eq!(config.max_concurrent_extractions, 5);
        assert_eq!(config.buffer_size, 64 * 1024);
    }

    #[tokio::test]
    async fn test_async_metrics() {
        let metrics = AsyncMetrics::new();

        metrics.record_operation_start();
        metrics.record_operation_complete(1024);

        let stats = metrics.get_stats();
        assert_eq!(stats.total_operations, 1);
        assert_eq!(stats.completed_operations, 1);
        assert_eq!(stats.total_bytes_read, 1024);
        assert_eq!(stats.active_operations, 0);
    }

    #[tokio::test]
    async fn test_async_reader_creation() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let cursor = Cursor::new(data);
        let session = Arc::new(SessionTracker::new());

        let reader = AsyncArchiveReader::new(cursor, session);
        assert!(!reader.is_under_pressure());
    }

    #[tokio::test]
    async fn test_async_read_at() {
        let data = vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15];
        let cursor = Cursor::new(data);
        let session = Arc::new(SessionTracker::new());

        let reader = AsyncArchiveReader::new(cursor, session);

        let mut buffer = [0u8; 4];
        let bytes_read = reader.read_at(5, &mut buffer).await.unwrap();
        assert_eq!(bytes_read, 4);
        assert_eq!(buffer, [5, 6, 7, 8]);
    }

    #[tokio::test]
    async fn test_async_read_exact_at() {
        let data = vec![10, 20, 30, 40, 50, 60, 70, 80, 90, 100];
        let cursor = Cursor::new(data);
        let session = Arc::new(SessionTracker::new());

        let reader = AsyncArchiveReader::new(cursor, session);

        let mut buffer = [0u8; 3];
        reader.read_exact_at(2, &mut buffer).await.unwrap();
        assert_eq!(buffer, [30, 40, 50]);
    }

    #[tokio::test]
    async fn test_async_read_oversized_buffer() {
        let data = vec![1, 2, 3];
        let cursor = Cursor::new(data);
        let session = Arc::new(SessionTracker::new());

        let config = AsyncConfig {
            max_async_memory: 2, // Very small limit
            ..Default::default()
        };

        let reader = AsyncArchiveReader::with_config(cursor, config, session);

        let mut buffer = [0u8; 5]; // Exceeds the limit
        let result = reader.read_at(0, &mut buffer).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("exceeds maximum allowed size")
        );
    }

    #[tokio::test]
    async fn test_concurrent_file_extraction() {
        let data = vec![0u8; 1000]; // 1 KB of zeros
        let cursor = Cursor::new(data);
        let session = Arc::new(SessionTracker::new());

        let reader = AsyncArchiveReader::new(cursor, session);

        let requests = vec![
            ("file1.txt".to_string(), 0, 100),
            ("file2.txt".to_string(), 100, 100),
            ("file3.txt".to_string(), 200, 100),
        ];

        let results = reader.extract_files_concurrent(requests).await.unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, "file1.txt");
        assert_eq!(results[0].1.len(), 100);
    }

    #[tokio::test]
    async fn test_too_many_concurrent_extractions() {
        let data = vec![0u8; 1000];
        let cursor = Cursor::new(data);
        let session = Arc::new(SessionTracker::new());

        let config = AsyncConfig {
            max_concurrent_extractions: 2,
            ..Default::default()
        };

        let reader = AsyncArchiveReader::with_config(cursor, config, session);

        // Request more than max_concurrent_extractions * 2
        let requests = (0..6)
            .map(|i| (format!("file{}.txt", i), i * 100, 50))
            .collect();

        let result = reader.extract_files_concurrent(requests).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("Too many concurrent")
        );
    }

    #[tokio::test]
    async fn test_decompression_monitor() {
        let monitor = AsyncDecompressionMonitor::new(1024, Duration::from_millis(100), 64);

        // Normal operation should succeed
        monitor.check_progress(512).await.unwrap();

        // Exceeding the size limit should fail
        let result = monitor.check_progress(2048).await;
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("size limit exceeded")
        );
    }

    #[tokio::test]
    async fn test_decompression_monitor_cancellation() {
        let monitor = AsyncDecompressionMonitor::new(1024, Duration::from_secs(10), 64);

        monitor.request_cancellation();

        let result = monitor.check_progress(512).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cancelled"));
    }

    #[tokio::test]
    async fn test_reader_shutdown() {
        let data = vec![1, 2, 3, 4, 5];
        let cursor = Cursor::new(data);
        let session = Arc::new(SessionTracker::new());

        let reader = AsyncArchiveReader::new(cursor, session);

        // Should shut down cleanly
        reader.shutdown().await.unwrap();

        // New operations should fail after shutdown
        let mut buffer = [0u8; 2];
        let result = reader.read_at(0, &mut buffer).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_resource_pressure_detection() {
        let data = vec![1, 2, 3, 4, 5];
        let cursor = Cursor::new(data);
        let session = Arc::new(SessionTracker::new());

        let config = AsyncConfig {
            max_concurrent_ops: 1, // Very limited
            ..Default::default()
        };

        let reader = AsyncArchiveReader::with_config(cursor, config, session);

        // Initially the reader should not be under pressure
        assert!(!reader.is_under_pressure());

        // Hold the only permit to simulate an in-flight operation
        let _permit = reader.active_operations.acquire().await.unwrap();

        // Now the reader should report pressure
        assert!(reader.is_under_pressure());
    }
}