// scirs2_core/gpu/memory_management/transfer.rs
//! Transfer queue and buffer lifetime management for GPU memory.
//!
//! This submodule provides [`TransferQueue`], [`BufferLifetime`],
//! and [`MemoryPressure`] for managing GPU memory transfers and lifetimes.

6use super::pool::{BufferHandle, MemoryError, MemoryResult};
7use crate::gpu::{GpuBuffer, GpuDataType, GpuError};
8use std::collections::{BTreeMap, VecDeque};
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10use std::sync::{Arc, Mutex, Weak};
11use std::time::{Duration, Instant};
12
/// Transfer queue for optimizing CPU↔GPU transfers.
///
/// Tracks queued transfers in FIFO order and keeps running counters of
/// completed transfers and total bytes moved. All mutable state is behind
/// a `Mutex` or atomics, so the queue can be shared across threads.
#[derive(Debug)]
pub struct TransferQueue {
    /// Transfers waiting to be processed, oldest first.
    pending_transfers: Arc<Mutex<VecDeque<Transfer>>>,
    /// Number of transfers processed so far.
    completed_count: Arc<AtomicUsize>,
    /// Total bytes moved by completed transfers.
    total_bytes_transferred: Arc<AtomicUsize>,
    /// Whether pinned (page-locked) host memory is requested for transfers.
    /// NOTE(review): only stored and reported here; the actual pinning is
    /// presumably done by the transfer backend — confirm at call sites.
    use_pinned_memory: bool,
}
25
/// A single queued transfer operation.
#[derive(Debug, Clone)]
struct Transfer {
    /// Unique, monotonically increasing id assigned at queue time.
    id: u64,
    /// Direction of the copy (host/device).
    direction: TransferDirection,
    /// Size of the transfer in bytes.
    size: usize,
    /// When the transfer was enqueued.
    queued_at: Instant,
}
34
/// Direction of a memory transfer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TransferDirection {
    /// Copy from CPU (host) memory to GPU (device) memory.
    HostToDevice,
    /// Copy from GPU (device) memory to CPU (host) memory.
    DeviceToHost,
    /// Copy between two device buffers.
    DeviceToDevice,
}
42
43impl TransferQueue {
44    /// Create a new transfer queue
45    pub fn new() -> Self {
46        Self::with_pinned_memory(true)
47    }
48
49    /// Create a transfer queue with pinned memory option
50    pub fn with_pinned_memory(use_pinned_memory: bool) -> Self {
51        Self {
52            pending_transfers: Arc::new(Mutex::new(VecDeque::new())),
53            completed_count: Arc::new(AtomicUsize::new(0)),
54            total_bytes_transferred: Arc::new(AtomicUsize::new(0)),
55            use_pinned_memory,
56        }
57    }
58
59    /// Queue a transfer
60    pub fn queue_transfer(&self, direction: TransferDirection, size: usize) -> MemoryResult<u64> {
61        static COUNTER: AtomicU64 = AtomicU64::new(1);
62        let id = COUNTER.fetch_add(1, Ordering::Relaxed);
63
64        let transfer = Transfer {
65            id,
66            direction,
67            size,
68            queued_at: Instant::now(),
69        };
70
71        let mut pending = self.pending_transfers.lock().map_err(|_| {
72            MemoryError::GpuError(GpuError::Other("Failed to lock transfers".to_string()))
73        })?;
74
75        pending.push_back(transfer);
76        Ok(id)
77    }
78
79    /// Process the next transfer
80    pub fn process_next(&self) -> MemoryResult<Option<u64>> {
81        let mut pending = self.pending_transfers.lock().map_err(|_| {
82            MemoryError::GpuError(GpuError::Other("Failed to lock transfers".to_string()))
83        })?;
84
85        if let Some(transfer) = pending.pop_front() {
86            self.completed_count.fetch_add(1, Ordering::Relaxed);
87            self.total_bytes_transferred
88                .fetch_add(transfer.size, Ordering::Relaxed);
89            Ok(Some(transfer.id))
90        } else {
91            Ok(None)
92        }
93    }
94
95    /// Get the number of pending transfers
96    pub fn pending_count(&self) -> MemoryResult<usize> {
97        let pending = self.pending_transfers.lock().map_err(|_| {
98            MemoryError::GpuError(GpuError::Other("Failed to lock transfers".to_string()))
99        })?;
100        Ok(pending.len())
101    }
102
103    /// Get the number of completed transfers
104    pub fn completed_count(&self) -> usize {
105        self.completed_count.load(Ordering::Relaxed)
106    }
107
108    /// Get total bytes transferred
109    pub fn total_bytes_transferred(&self) -> usize {
110        self.total_bytes_transferred.load(Ordering::Relaxed)
111    }
112
113    /// Check if using pinned memory
114    pub fn uses_pinned_memory(&self) -> bool {
115        self.use_pinned_memory
116    }
117
118    /// Get transfer statistics
119    pub fn statistics(&self) -> TransferStatistics {
120        TransferStatistics {
121            pending_transfers: self.pending_count().unwrap_or(0),
122            completed_transfers: self.completed_count(),
123            total_bytes_transferred: self.total_bytes_transferred(),
124            uses_pinned_memory: self.uses_pinned_memory(),
125        }
126    }
127}
128
impl Default for TransferQueue {
    /// Equivalent to [`TransferQueue::new`]: pinned memory enabled.
    fn default() -> Self {
        Self::new()
    }
}
134
/// Snapshot of transfer queue statistics.
#[derive(Debug, Clone)]
pub struct TransferStatistics {
    /// Transfers still waiting in the queue.
    pub pending_transfers: usize,
    /// Transfers processed so far.
    pub completed_transfers: usize,
    /// Total bytes moved by completed transfers.
    pub total_bytes_transferred: usize,
    /// Whether the queue uses pinned host memory.
    pub uses_pinned_memory: bool,
}
143
/// RAII buffer lifetime management.
///
/// Holds a GPU buffer and, on drop, returns it to the originating pool's
/// free list for its size class (if the pool is still alive).
pub struct BufferLifetime<T: GpuDataType> {
    /// Handle identifying this buffer within the pool.
    handle: BufferHandle,
    /// The managed GPU buffer.
    buffer: Arc<GpuBuffer<T>>,
    /// Weak reference back to the pool's free lists, keyed by buffer size.
    /// Weak so a dropped pool is not kept alive by outstanding guards.
    pool: Weak<Mutex<BTreeMap<usize, VecDeque<(BufferHandle, Arc<GpuBuffer<T>>)>>>>,
    /// Size class used as the pool key when the buffer is returned.
    size: usize,
}
151
152impl<T: GpuDataType> BufferLifetime<T> {
153    /// Create a new buffer lifetime guard
154    pub fn new(
155        handle: BufferHandle,
156        buffer: Arc<GpuBuffer<T>>,
157        pool: Weak<Mutex<BTreeMap<usize, VecDeque<(BufferHandle, Arc<GpuBuffer<T>>)>>>>,
158        size: usize,
159    ) -> Self {
160        Self {
161            handle,
162            buffer,
163            pool,
164            size,
165        }
166    }
167
168    /// Get the buffer handle
169    pub fn handle(&self) -> BufferHandle {
170        self.handle
171    }
172
173    /// Get a reference to the buffer
174    pub fn buffer(&self) -> &Arc<GpuBuffer<T>> {
175        &self.buffer
176    }
177}
178
179impl<T: GpuDataType> Drop for BufferLifetime<T> {
180    fn drop(&mut self) {
181        // Return buffer to pool if pool still exists
182        if let Some(pool_arc) = self.pool.upgrade() {
183            if let Ok(mut pool) = pool_arc.lock() {
184                let buffers = pool.entry(self.size).or_insert_with(VecDeque::new);
185                buffers.push_back((self.handle, self.buffer.clone()));
186            }
187        }
188    }
189}
190
/// Memory pressure tracker.
///
/// Tracks current usage against a fixed byte limit and classifies pressure
/// into [`MemoryPressureLevel`]s via two usage-ratio thresholds.
#[derive(Debug)]
pub struct MemoryPressure {
    /// Current memory usage in bytes.
    current_usage: Arc<AtomicUsize>,
    /// Memory limit in bytes.
    memory_limit: usize,
    /// Usage ratio at which pressure becomes `Warning` (0.7 = 70%).
    warning_threshold: f64,
    /// Usage ratio at which pressure becomes `Critical` (0.9 = 90%).
    critical_threshold: f64,
}
202
203impl MemoryPressure {
204    /// Create a new memory pressure tracker
205    pub fn new(memory_limit: usize) -> Self {
206        Self {
207            current_usage: Arc::new(AtomicUsize::new(0)),
208            memory_limit,
209            warning_threshold: 0.7,  // 70%
210            critical_threshold: 0.9, // 90%
211        }
212    }
213
214    /// Record memory allocation
215    pub fn allocate(&self, size: usize) {
216        self.current_usage.fetch_add(size, Ordering::Relaxed);
217    }
218
219    /// Record memory deallocation
220    pub fn deallocate(&self, size: usize) {
221        self.current_usage.fetch_sub(size, Ordering::Relaxed);
222    }
223
224    /// Get current memory usage
225    pub fn current_usage(&self) -> usize {
226        self.current_usage.load(Ordering::Relaxed)
227    }
228
229    /// Get memory usage ratio
230    pub fn usage_ratio(&self) -> f64 {
231        self.current_usage() as f64 / self.memory_limit as f64
232    }
233
234    /// Get memory pressure level
235    pub fn pressure_level(&self) -> MemoryPressureLevel {
236        let ratio = self.usage_ratio();
237
238        if ratio >= self.critical_threshold {
239            MemoryPressureLevel::Critical
240        } else if ratio >= self.warning_threshold {
241            MemoryPressureLevel::Warning
242        } else {
243            MemoryPressureLevel::Normal
244        }
245    }
246
247    /// Check if under memory pressure
248    pub fn is_under_pressure(&self) -> bool {
249        matches!(
250            self.pressure_level(),
251            MemoryPressureLevel::Warning | MemoryPressureLevel::Critical
252        )
253    }
254
255    /// Get available memory
256    pub fn available_memory(&self) -> usize {
257        self.memory_limit.saturating_sub(self.current_usage())
258    }
259}
260
/// Memory pressure levels reported by [`MemoryPressure::pressure_level`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryPressureLevel {
    /// Usage is below the warning threshold.
    Normal,
    /// Usage is at or above the warning threshold but below critical.
    Warning,
    /// Usage is at or above the critical threshold.
    Critical,
}
268
#[cfg(test)]
mod tests {
    use super::*;

    // FIX: this function was missing the `#[test]` attribute, so the test
    // harness silently never ran it.
    #[test]
    fn test_transfer_queue() {
        let queue = TransferQueue::new();

        let id1 = queue
            .queue_transfer(TransferDirection::HostToDevice, 1024)
            .expect("Failed to queue transfer");
        // The second transfer exists only to exercise FIFO ordering.
        let _id2 = queue
            .queue_transfer(TransferDirection::DeviceToHost, 2048)
            .expect("Failed to queue transfer");

        assert_eq!(queue.pending_count().expect("Failed to get count"), 2);

        // FIFO: the first transfer queued must be the first processed.
        let processed = queue.process_next().expect("Failed to process");
        assert_eq!(processed, Some(id1));

        assert_eq!(queue.pending_count().expect("Failed to get count"), 1);
        assert_eq!(queue.completed_count(), 1);
    }

    #[test]
    fn test_transfer_queue_statistics() {
        let queue = TransferQueue::new();

        queue
            .queue_transfer(TransferDirection::HostToDevice, 1024)
            .expect("Failed to queue");
        queue.process_next().expect("Failed to process");

        let stats = queue.statistics();
        assert_eq!(stats.completed_transfers, 1);
        assert_eq!(stats.total_bytes_transferred, 1024);
        assert!(stats.uses_pinned_memory);
    }

    #[test]
    fn test_memory_pressure() {
        let pressure = MemoryPressure::new(10000);

        assert_eq!(pressure.pressure_level(), MemoryPressureLevel::Normal);
        assert!(!pressure.is_under_pressure());

        // 75% usage crosses the 70% warning threshold.
        pressure.allocate(7500);
        assert_eq!(pressure.pressure_level(), MemoryPressureLevel::Warning);
        assert!(pressure.is_under_pressure());

        // 95% usage crosses the 90% critical threshold.
        pressure.allocate(2000);
        assert_eq!(pressure.pressure_level(), MemoryPressureLevel::Critical);

        // 45% usage drops back below the warning threshold.
        pressure.deallocate(5000);
        assert_eq!(pressure.pressure_level(), MemoryPressureLevel::Normal);
    }

    #[test]
    fn test_memory_pressure_available() {
        let pressure = MemoryPressure::new(10000);

        assert_eq!(pressure.available_memory(), 10000);

        pressure.allocate(3000);
        assert_eq!(pressure.available_memory(), 7000);

        pressure.deallocate(1000);
        assert_eq!(pressure.available_memory(), 8000);
    }
}