scirs2_core/gpu/memory_management/
transfer.rs1use super::pool::{BufferHandle, MemoryError, MemoryResult};
7use crate::gpu::{GpuBuffer, GpuDataType, GpuError};
8use std::collections::{BTreeMap, VecDeque};
9use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
10use std::sync::{Arc, Mutex, Weak};
11use std::time::{Duration, Instant};
12
13#[derive(Debug)]
15pub struct TransferQueue {
16 pending_transfers: Arc<Mutex<VecDeque<Transfer>>>,
18 completed_count: Arc<AtomicUsize>,
20 total_bytes_transferred: Arc<AtomicUsize>,
22 use_pinned_memory: bool,
24}
25
26#[derive(Debug, Clone)]
28struct Transfer {
29 id: u64,
30 direction: TransferDirection,
31 size: usize,
32 queued_at: Instant,
33}
34
/// Which way a transfer moves data.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransferDirection {
    /// From host memory into device memory.
    HostToDevice,
    /// From device memory back to host memory.
    DeviceToHost,
    /// Between two device-side locations.
    DeviceToDevice,
}
42
43impl TransferQueue {
44 pub fn new() -> Self {
46 Self::with_pinned_memory(true)
47 }
48
49 pub fn with_pinned_memory(use_pinned_memory: bool) -> Self {
51 Self {
52 pending_transfers: Arc::new(Mutex::new(VecDeque::new())),
53 completed_count: Arc::new(AtomicUsize::new(0)),
54 total_bytes_transferred: Arc::new(AtomicUsize::new(0)),
55 use_pinned_memory,
56 }
57 }
58
59 pub fn queue_transfer(&self, direction: TransferDirection, size: usize) -> MemoryResult<u64> {
61 static COUNTER: AtomicU64 = AtomicU64::new(1);
62 let id = COUNTER.fetch_add(1, Ordering::Relaxed);
63
64 let transfer = Transfer {
65 id,
66 direction,
67 size,
68 queued_at: Instant::now(),
69 };
70
71 let mut pending = self.pending_transfers.lock().map_err(|_| {
72 MemoryError::GpuError(GpuError::Other("Failed to lock transfers".to_string()))
73 })?;
74
75 pending.push_back(transfer);
76 Ok(id)
77 }
78
79 pub fn process_next(&self) -> MemoryResult<Option<u64>> {
81 let mut pending = self.pending_transfers.lock().map_err(|_| {
82 MemoryError::GpuError(GpuError::Other("Failed to lock transfers".to_string()))
83 })?;
84
85 if let Some(transfer) = pending.pop_front() {
86 self.completed_count.fetch_add(1, Ordering::Relaxed);
87 self.total_bytes_transferred
88 .fetch_add(transfer.size, Ordering::Relaxed);
89 Ok(Some(transfer.id))
90 } else {
91 Ok(None)
92 }
93 }
94
95 pub fn pending_count(&self) -> MemoryResult<usize> {
97 let pending = self.pending_transfers.lock().map_err(|_| {
98 MemoryError::GpuError(GpuError::Other("Failed to lock transfers".to_string()))
99 })?;
100 Ok(pending.len())
101 }
102
103 pub fn completed_count(&self) -> usize {
105 self.completed_count.load(Ordering::Relaxed)
106 }
107
108 pub fn total_bytes_transferred(&self) -> usize {
110 self.total_bytes_transferred.load(Ordering::Relaxed)
111 }
112
113 pub fn uses_pinned_memory(&self) -> bool {
115 self.use_pinned_memory
116 }
117
118 pub fn statistics(&self) -> TransferStatistics {
120 TransferStatistics {
121 pending_transfers: self.pending_count().unwrap_or(0),
122 completed_transfers: self.completed_count(),
123 total_bytes_transferred: self.total_bytes_transferred(),
124 uses_pinned_memory: self.uses_pinned_memory(),
125 }
126 }
127}
128
129impl Default for TransferQueue {
130 fn default() -> Self {
131 Self::new()
132 }
133}
134
/// Point-in-time copy of a [`TransferQueue`]'s counters, as returned by `statistics()`.
#[derive(Clone, Debug)]
pub struct TransferStatistics {
    /// Transfers still queued (reported as 0 if the queue lock could not be taken).
    pub pending_transfers: usize,
    /// Transfers retired so far.
    pub completed_transfers: usize,
    /// Sum of the sizes of all retired transfers.
    pub total_bytes_transferred: usize,
    /// Whether the queue was configured for pinned memory.
    pub uses_pinned_memory: bool,
}
143
144pub struct BufferLifetime<T: GpuDataType> {
146 handle: BufferHandle,
147 buffer: Arc<GpuBuffer<T>>,
148 pool: Weak<Mutex<BTreeMap<usize, VecDeque<(BufferHandle, Arc<GpuBuffer<T>>)>>>>,
149 size: usize,
150}
151
152impl<T: GpuDataType> BufferLifetime<T> {
153 pub fn new(
155 handle: BufferHandle,
156 buffer: Arc<GpuBuffer<T>>,
157 pool: Weak<Mutex<BTreeMap<usize, VecDeque<(BufferHandle, Arc<GpuBuffer<T>>)>>>>,
158 size: usize,
159 ) -> Self {
160 Self {
161 handle,
162 buffer,
163 pool,
164 size,
165 }
166 }
167
168 pub fn handle(&self) -> BufferHandle {
170 self.handle
171 }
172
173 pub fn buffer(&self) -> &Arc<GpuBuffer<T>> {
175 &self.buffer
176 }
177}
178
179impl<T: GpuDataType> Drop for BufferLifetime<T> {
180 fn drop(&mut self) {
181 if let Some(pool_arc) = self.pool.upgrade() {
183 if let Ok(mut pool) = pool_arc.lock() {
184 let buffers = pool.entry(self.size).or_insert_with(VecDeque::new);
185 buffers.push_back((self.handle, self.buffer.clone()));
186 }
187 }
188 }
189}
190
/// Tracks memory usage against a fixed limit and classifies the resulting pressure.
///
/// Usage is kept in an atomic counter, so `allocate`/`deallocate` work through `&self`.
/// The counter saturates at `0` and `usize::MAX` instead of wrapping. An unmatched or
/// double `deallocate` therefore cannot make usage jump to a huge value.
#[derive(Debug)]
pub struct MemoryPressure {
    /// Amount currently accounted as in use (same unit as `memory_limit`, presumably bytes).
    current_usage: Arc<AtomicUsize>,
    /// Budget that `current_usage` is measured against.
    memory_limit: usize,
    /// Usage ratio at or above which the level is `Warning`.
    warning_threshold: f64,
    /// Usage ratio at or above which the level is `Critical`.
    critical_threshold: f64,
}

impl MemoryPressure {
    /// Default usage ratio at which pressure becomes [`MemoryPressureLevel::Warning`].
    pub const DEFAULT_WARNING_THRESHOLD: f64 = 0.7;
    /// Default usage ratio at which pressure becomes [`MemoryPressureLevel::Critical`].
    pub const DEFAULT_CRITICAL_THRESHOLD: f64 = 0.9;

    /// Creates a tracker with zero usage and the default 0.7 / 0.9 thresholds.
    pub fn new(memory_limit: usize) -> Self {
        Self::with_thresholds(
            memory_limit,
            Self::DEFAULT_WARNING_THRESHOLD,
            Self::DEFAULT_CRITICAL_THRESHOLD,
        )
    }

    /// Creates a tracker with zero usage and custom warning/critical usage ratios.
    ///
    /// # Panics
    /// Panics if `warning_threshold > critical_threshold` or if either value is NaN.
    pub fn with_thresholds(
        memory_limit: usize,
        warning_threshold: f64,
        critical_threshold: f64,
    ) -> Self {
        // `<=` is false for NaN, so this also rejects NaN thresholds.
        assert!(
            warning_threshold <= critical_threshold,
            "warning_threshold ({}) must not exceed critical_threshold ({})",
            warning_threshold,
            critical_threshold
        );
        Self {
            current_usage: Arc::new(AtomicUsize::new(0)),
            memory_limit,
            warning_threshold,
            critical_threshold,
        }
    }

    /// Atomically replaces the usage counter with `f(current)`.
    fn update_usage(&self, f: impl Fn(usize) -> usize) {
        // The closure always returns `Some`, so `fetch_update` cannot fail.
        let _ = self
            .current_usage
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |usage| Some(f(usage)));
    }

    /// Records `size` more units in use (saturating at `usize::MAX`).
    pub fn allocate(&self, size: usize) {
        self.update_usage(|usage| usage.saturating_add(size));
    }

    /// Records `size` fewer units in use (saturating at 0 rather than wrapping around).
    pub fn deallocate(&self, size: usize) {
        self.update_usage(|usage| usage.saturating_sub(size));
    }

    /// Currently recorded usage.
    pub fn current_usage(&self) -> usize {
        self.current_usage.load(Ordering::Relaxed)
    }

    /// Usage divided by the limit.
    ///
    /// With a zero limit this is `0.0` while nothing is in use, and `f64::INFINITY` otherwise.
    pub fn usage_ratio(&self) -> f64 {
        let usage = self.current_usage();
        if self.memory_limit == 0 {
            // Avoid 0/0 = NaN; any usage against an empty budget is unbounded pressure.
            return if usage == 0 { 0.0 } else { f64::INFINITY };
        }
        usage as f64 / self.memory_limit as f64
    }

    /// Classifies the current usage ratio against the configured thresholds.
    pub fn pressure_level(&self) -> MemoryPressureLevel {
        let ratio = self.usage_ratio();

        if ratio >= self.critical_threshold {
            MemoryPressureLevel::Critical
        } else if ratio >= self.warning_threshold {
            MemoryPressureLevel::Warning
        } else {
            MemoryPressureLevel::Normal
        }
    }

    /// `true` when the level is `Warning` or `Critical`.
    pub fn is_under_pressure(&self) -> bool {
        matches!(
            self.pressure_level(),
            MemoryPressureLevel::Warning | MemoryPressureLevel::Critical
        )
    }

    /// Remaining budget, clamped at 0 when usage exceeds the limit.
    pub fn available_memory(&self) -> usize {
        self.memory_limit.saturating_sub(self.current_usage())
    }
}

/// Severity reported by [`MemoryPressure::pressure_level`].
///
/// Variants are declared from least to most severe, so the derived ordering compares severity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum MemoryPressureLevel {
    /// Usage ratio below the warning threshold.
    Normal,
    /// Usage ratio at or above the warning threshold but below the critical one.
    Warning,
    /// Usage ratio at or above the critical threshold.
    Critical,
}
268
#[cfg(test)]
mod tests {
    use super::*;

    // Previously missing `#[test]`, so this never ran; it now also drains the second transfer.
    #[test]
    fn test_transfer_queue() {
        let queue = TransferQueue::new();

        let id1 = queue
            .queue_transfer(TransferDirection::HostToDevice, 1024)
            .expect("Failed to queue transfer");
        let id2 = queue
            .queue_transfer(TransferDirection::DeviceToHost, 2048)
            .expect("Failed to queue transfer");

        assert_eq!(queue.pending_count().expect("Failed to get count"), 2);

        // Transfers are retired in FIFO order.
        let processed = queue.process_next().expect("Failed to process");
        assert_eq!(processed, Some(id1));
        assert_eq!(queue.pending_count().expect("Failed to get count"), 1);
        assert_eq!(queue.completed_count(), 1);

        let processed = queue.process_next().expect("Failed to process");
        assert_eq!(processed, Some(id2));
        assert_eq!(queue.pending_count().expect("Failed to get count"), 0);
        assert_eq!(queue.completed_count(), 2);
        assert_eq!(queue.total_bytes_transferred(), 1024 + 2048);

        // An empty queue yields `None` and leaves the counters untouched.
        assert_eq!(queue.process_next().expect("Failed to process"), None);
        assert_eq!(queue.completed_count(), 2);
    }

    #[test]
    fn test_transfer_queue_statistics() {
        let queue = TransferQueue::new();

        queue
            .queue_transfer(TransferDirection::HostToDevice, 1024)
            .expect("Failed to queue");
        queue.process_next().expect("Failed to process");

        let stats = queue.statistics();
        assert_eq!(stats.completed_transfers, 1);
        assert_eq!(stats.total_bytes_transferred, 1024);
        assert!(stats.uses_pinned_memory);
    }

    #[test]
    fn test_memory_pressure() {
        let pressure = MemoryPressure::new(10000);

        assert_eq!(pressure.pressure_level(), MemoryPressureLevel::Normal);
        assert!(!pressure.is_under_pressure());

        pressure.allocate(7500);
        assert_eq!(pressure.pressure_level(), MemoryPressureLevel::Warning);
        assert!(pressure.is_under_pressure());

        pressure.allocate(2000);
        assert_eq!(pressure.pressure_level(), MemoryPressureLevel::Critical);

        pressure.deallocate(5000);
        assert_eq!(pressure.pressure_level(), MemoryPressureLevel::Normal);
    }

    #[test]
    fn test_memory_pressure_available() {
        let pressure = MemoryPressure::new(10000);

        assert_eq!(pressure.available_memory(), 10000);

        pressure.allocate(3000);
        assert_eq!(pressure.available_memory(), 7000);

        pressure.deallocate(1000);
        assert_eq!(pressure.available_memory(), 8000);
    }
}