//! GPU memory management: pooled device allocations, transfer descriptors,
//! and bandwidth tracking.

use crate::error::{ClusteringError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Instant;

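/// Strategy used when allocating and moving GPU memory.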
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemoryStrategy {
    /// Unified (managed) memory shared between host and device.
    Unified,
    /// Explicit device allocations with manual host/device transfers.
    Explicit,
    /// Allocations served from a pre-sized memory pool.
    Pooled { pool_size_mb: usize },
    /// Zero-copy access to host memory from the device.
    ZeroCopy,
    /// Pick a strategy automatically based on the workload.
    Adaptive,
}

impl Default for MemoryStrategy {
    fn default() -> Self {
        MemoryStrategy::Adaptive
    }
}

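/// Pool-based GPU memory manager that recycles device allocations by
/// power-of-two size class and tracks usage statistics.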
#[derive(Debug)]
pub struct GpuMemoryManager {
    /// Free blocks available for reuse, grouped by power-of-two size class.
    pools: HashMap<usize, Vec<GpuMemoryBlock>>,
    /// Bytes currently allocated on the device, including pooled blocks.
    total_allocated: usize,
    /// High-water mark of `total_allocated`.
    peak_usage: usize,
    /// Alignment in bytes applied to every allocation (must be a power of two).
    alignment: usize,
    /// Maximum number of free blocks retained per size-class pool.
    max_pool_size: usize,
    /// Running allocation statistics.
    stats: MemoryStats,
}

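/// A single device memory allocation handed out by [`GpuMemoryManager`].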
#[derive(Debug, Clone)]
pub struct GpuMemoryBlock {
    /// Raw device address of the allocation.
    pub device_ptr: usize,
    /// Size of the allocation in bytes (after alignment).
    pub size: usize,
    /// Whether the block is currently handed out to a caller.
    pub in_use: bool,
    /// When the block was last allocated.
    pub allocated_at: Instant,
}

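/// Counters describing allocator behaviour over the manager's lifetime.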
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MemoryStats {
    /// Total bytes currently allocated on the device.
    pub total_allocated: usize,
    /// Highest value `total_allocated` has reached.
    pub peak_usage: usize,
    /// Number of allocation requests served.
    pub allocation_count: usize,
    /// Number of deallocations performed.
    pub deallocation_count: usize,
    /// Allocations satisfied by reusing a pooled block.
    pub pool_hits: usize,
    /// Allocations that required a fresh device allocation.
    pub pool_misses: usize,
}

impl GpuMemoryManager {
    /// Creates a new manager with the given byte alignment (a power of two)
    /// and the maximum number of free blocks to keep per size class.
    pub fn new(alignment: usize, max_pool_size: usize) -> Self {
        Self {
            pools: HashMap::new(),
            total_allocated: 0,
            peak_usage: 0,
            alignment,
            max_pool_size,
            stats: MemoryStats::default(),
        }
    }

    /// Allocates `size` bytes of device memory, reusing a pooled block of a
    /// suitable size when one is available.
    pub fn allocate(&mut self, size: usize) -> Result<GpuMemoryBlock> {
        // Round up to the next multiple of `alignment` (assumes a power-of-two alignment).
        let aligned_size = (size + self.alignment - 1) & !(self.alignment - 1);
        let size_class = self.get_size_class(aligned_size);

        self.stats.allocation_count += 1;

        // Try to reuse a free block from this size class; remove it from the
        // pool so the same block is never tracked twice.
        if let Some(pool) = self.pools.get_mut(&size_class) {
            if let Some(idx) = pool
                .iter()
                .position(|block| !block.in_use && block.size >= aligned_size)
            {
                let mut block = pool.swap_remove(idx);
                block.in_use = true;
                block.allocated_at = Instant::now();
                self.stats.pool_hits += 1;
                return Ok(block);
            }
        }

        // No reusable block available: allocate fresh device memory.
        self.stats.pool_misses += 1;
        let device_ptr = self.allocate_device_memory(aligned_size)?;
        self.total_allocated += aligned_size;
        self.peak_usage = self.peak_usage.max(self.total_allocated);
        self.stats.total_allocated = self.total_allocated;
        self.stats.peak_usage = self.peak_usage;

        Ok(GpuMemoryBlock {
            device_ptr,
            size: aligned_size,
            in_use: true,
            allocated_at: Instant::now(),
        })
    }

    /// Returns a block to the allocator. The block is kept in its size-class
    /// pool for reuse unless that pool is already full, in which case the
    /// underlying device memory is released.
    pub fn deallocate(&mut self, mut block: GpuMemoryBlock) -> Result<()> {
        block.in_use = false;
        self.stats.deallocation_count += 1;

        let size_class = self.get_size_class(block.size);

        let pool = self.pools.entry(size_class).or_default();
        if pool.len() < self.max_pool_size {
            pool.push(block);
        } else {
            self.free_device_memory(block.device_ptr, block.size)?;
            self.total_allocated -= block.size;
            self.stats.total_allocated = self.total_allocated;
        }

        Ok(())
    }

    /// Releases every free block held in the pools and clears them.
    pub fn clear_pools(&mut self) -> Result<()> {
        for pool in self.pools.values() {
            for block in pool {
                if !block.in_use {
                    self.free_device_memory(block.device_ptr, block.size)?;
                    self.total_allocated -= block.size;
                }
            }
        }
        self.pools.clear();
        self.stats.total_allocated = self.total_allocated;
        Ok(())
    }

    /// Returns the accumulated allocation statistics.
    pub fn get_stats(&self) -> &MemoryStats {
        &self.stats
    }

    /// Fraction of allocation requests that were satisfied from the pools.
    pub fn pool_efficiency(&self) -> f64 {
        if self.stats.allocation_count == 0 {
            0.0
        } else {
            self.stats.pool_hits as f64 / self.stats.allocation_count as f64
        }
    }

    /// Bytes currently allocated on the device.
    pub fn current_usage(&self) -> usize {
        self.total_allocated
    }

    /// Highest number of bytes that have been allocated at any one time.
    pub fn peak_usage(&self) -> usize {
        self.peak_usage
    }

    /// Rounds `size` up to the next power of two, which is used as the
    /// size-class key for the pools.
    fn get_size_class(&self, size: usize) -> usize {
        if size == 0 {
            return 1;
        }
        let mut class = 1;
        while class < size {
            class <<= 1;
        }
        class
    }

    /// Allocates `size` bytes of device memory and returns its device address.
    ///
    /// This implementation simulates the allocation and returns a synthetic
    /// pointer; a real backend would call into the GPU runtime here.
    fn allocate_device_memory(&self, size: usize) -> Result<usize> {
        if size == 0 {
            return Err(ClusteringError::InvalidInput(
                "Cannot allocate zero bytes".to_string(),
            ));
        }

        // Reject requests above 16 GiB.
        if size > 16 * 1024 * 1024 * 1024 {
            return Err(ClusteringError::InvalidInput(
                "Allocation too large".to_string(),
            ));
        }

        // Simulated device pointer.
        Ok(0x1000_0000 + size)
    }

    /// Releases device memory. This implementation is a no-op placeholder for
    /// a real GPU runtime call.
    fn free_device_memory(&self, _device_ptr: usize, _size: usize) -> Result<()> {
        Ok(())
    }
}

impl MemoryStats {
    /// Ratio of deallocations to allocations; 1.0 means every allocation has
    /// been returned.
    pub fn allocation_efficiency(&self) -> f64 {
        if self.allocation_count == 0 {
            1.0
        } else {
            self.deallocation_count as f64 / self.allocation_count as f64
        }
    }

    /// Currently allocated bytes divided by the number of allocation requests.
    pub fn average_allocation_size(&self) -> f64 {
        if self.allocation_count == 0 {
            0.0
        } else {
            self.total_allocated as f64 / self.allocation_count as f64
        }
    }

    /// True when more allocations than deallocations have been recorded,
    /// which may indicate blocks that were never returned.
    pub fn has_potential_leaks(&self) -> bool {
        self.allocation_count > self.deallocation_count
    }
}

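/// Description of a memory copy between host and device address spaces.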
#[derive(Debug, Clone)]
pub enum MemoryTransfer {
    /// Copy `size` bytes from host memory to device memory.
    HostToDevice {
        host_ptr: *const u8,
        device_ptr: usize,
        size: usize,
    },
    /// Copy `size` bytes from device memory to host memory.
    DeviceToHost {
        device_ptr: usize,
        host_ptr: *mut u8,
        size: usize,
    },
    /// Copy `size` bytes between two device buffers.
    DeviceToDevice {
        src_device_ptr: usize,
        dst_device_ptr: usize,
        size: usize,
    },
}

impl MemoryTransfer {
    /// Number of bytes moved by this transfer.
    pub fn size(&self) -> usize {
        match self {
            MemoryTransfer::HostToDevice { size, .. } => *size,
            MemoryTransfer::DeviceToHost { size, .. } => *size,
            MemoryTransfer::DeviceToDevice { size, .. } => *size,
        }
    }

    /// Performs the transfer. The current implementation is a placeholder; a
    /// real backend would issue the corresponding copy for each direction.
    pub fn execute(&self) -> Result<()> {
        match self {
            MemoryTransfer::HostToDevice { .. } => Ok(()),
            MemoryTransfer::DeviceToHost { .. } => Ok(()),
            MemoryTransfer::DeviceToDevice { .. } => Ok(()),
        }
    }
}

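/// Accumulates transfer sizes and durations to report memory bandwidth.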
#[derive(Debug, Clone, Default)]
pub struct BandwidthMonitor {
    /// Total bytes transferred.
    pub total_transferred: usize,
    /// Number of transfers recorded.
    pub transfer_count: usize,
    /// Total transfer time in microseconds.
    pub total_time_us: u64,
}

impl BandwidthMonitor {
    /// Records a completed transfer of `size` bytes that took `duration_us`
    /// microseconds.
    pub fn record_transfer(&mut self, size: usize, duration_us: u64) {
        self.total_transferred += size;
        self.transfer_count += 1;
        self.total_time_us += duration_us;
    }

    /// Average bandwidth over all recorded transfers, in GiB/s.
    pub fn average_bandwidth_gbps(&self) -> f64 {
        if self.total_time_us == 0 {
            0.0
        } else {
            let total_gb = self.total_transferred as f64 / (1024.0 * 1024.0 * 1024.0);
            let total_seconds = self.total_time_us as f64 / 1_000_000.0;
            total_gb / total_seconds
        }
    }

    /// Average number of bytes moved per transfer.
    pub fn average_transfer_size(&self) -> f64 {
        if self.transfer_count == 0 {
            0.0
        } else {
            self.total_transferred as f64 / self.transfer_count as f64
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_manager_creation() {
        let manager = GpuMemoryManager::new(256, 10);
        assert_eq!(manager.alignment, 256);
        assert_eq!(manager.max_pool_size, 10);
        assert_eq!(manager.current_usage(), 0);
    }

    #[test]
    fn test_memory_allocation() {
        let mut manager = GpuMemoryManager::new(256, 10);

        let block = manager.allocate(1024).expect("Operation failed");
        assert!(block.size >= 1024);
        assert!(block.in_use);

        assert_eq!(manager.get_stats().allocation_count, 1);
        assert_eq!(manager.get_stats().pool_misses, 1);
    }

    #[test]
    fn test_memory_pooling() {
        let mut manager = GpuMemoryManager::new(256, 10);

        let block = manager.allocate(1024).expect("Operation failed");
        manager.deallocate(block).expect("Operation failed");

        let _block2 = manager.allocate(1024).expect("Operation failed");
        assert!(manager.get_stats().pool_hits > 0);
    }

    #[test]
    fn test_memory_stats() {
        let stats = MemoryStats {
            allocation_count: 10,
            deallocation_count: 8,
            total_allocated: 1024,
            pool_hits: 5,
            ..Default::default()
        };

        assert_eq!(stats.allocation_efficiency(), 0.8);
        assert_eq!(stats.average_allocation_size(), 102.4);
        assert!(stats.has_potential_leaks());
    }

    #[test]
    fn test_bandwidth_monitor() {
        let mut monitor = BandwidthMonitor::default();

        monitor.record_transfer(1024 * 1024 * 1024, 1_000_000);

        assert_eq!(monitor.average_bandwidth_gbps(), 1.0);
        assert_eq!(monitor.average_transfer_size(), 1024.0 * 1024.0 * 1024.0);
    }

    #[test]
    fn test_memory_transfer() {
        let transfer = MemoryTransfer::HostToDevice {
            host_ptr: std::ptr::null(),
            device_ptr: 0x1000,
            size: 1024,
        };

        assert_eq!(transfer.size(), 1024);
        assert!(transfer.execute().is_ok());
    }
}