1use crate::errors::{Result, TrustformersError};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex, OnceLock};
5
/// Supported GPU compute backends.
///
/// NOTE(review): `Metal` is the unconditional `#[default]` variant, even on
/// non-macOS targets — confirm this is intended off-macOS.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum GpuBackend {
    /// NVIDIA CUDA.
    Cuda,
    /// AMD ROCm/HIP.
    Rocm,
    /// Apple Metal (the default backend).
    #[default]
    Metal,
    /// Cross-platform Vulkan compute.
    Vulkan,
    /// WebGPU (browser / wgpu).
    WebGpu,
    /// OpenCL.
    OpenCl,
    /// Intel GPU — presumably oneAPI/Level Zero; verify against backend impls.
    Intel,
    /// Plain CPU fallback; always available.
    Cpu,
}
27
/// Descriptor for a single compute device discovered on this machine.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuDevice {
    /// Unique device id used for lookups in `GpuManager`.
    pub id: usize,
    /// Human-readable device name.
    pub name: String,
    /// Backend API used to drive this device.
    pub backend: GpuBackend,
    /// Total device memory in bytes (0 for the CPU pseudo-device).
    pub memory_total: u64,
    /// Currently free device memory in bytes.
    pub memory_free: u64,
    /// Backend-specific capability string (e.g. "8.6" for CUDA), if known.
    pub compute_capability: Option<String>,
    /// Whether the device can currently be used.
    pub is_available: bool,
}
39
40impl GpuDevice {
41 pub fn cpu() -> Self {
43 Self {
44 id: 0,
45 name: "CPU".to_string(),
46 backend: GpuBackend::Cpu,
47 memory_total: 0,
48 memory_free: 0,
49 compute_capability: None,
50 is_available: true,
51 }
52 }
53
54 pub fn supports_tensor_cores(&self) -> bool {
56 matches!(self.backend, GpuBackend::Cuda)
57 && self.compute_capability.as_ref().map(|cc| cc.as_str() >= "7.0").unwrap_or(false)
58 }
59
60 pub fn memory_utilization(&self) -> f32 {
62 if self.memory_total == 0 {
63 0.0
64 } else {
65 1.0 - (self.memory_free as f32 / self.memory_total as f32)
66 }
67 }
68}
69
/// Simple first-fit memory pool tracking simulated allocations for one backend.
///
/// NOTE(review): pointers handed out here are synthetic ids, not real device
/// addresses — no actual device memory is allocated.
#[derive(Debug)]
pub struct GpuMemoryPool {
    #[allow(dead_code)]
    backend: GpuBackend,
    // Live allocations: synthetic ptr -> requested size in bytes.
    allocated_blocks: HashMap<usize, u64>,
    // Reusable freed blocks as (ptr, size) pairs.
    free_blocks: Vec<(usize, u64)>,
    // Bytes currently allocated.
    total_allocated: u64,
    // High-water mark of `total_allocated`.
    peak_allocated: u64,
}
80
81impl GpuMemoryPool {
82 pub fn new(backend: GpuBackend) -> Self {
83 Self {
84 backend,
85 allocated_blocks: HashMap::new(),
86 free_blocks: Vec::new(),
87 total_allocated: 0,
88 peak_allocated: 0,
89 }
90 }
91
92 pub fn allocate(&mut self, size: u64) -> Result<usize> {
94 if let Some(pos) = self.free_blocks.iter().position(|(_, block_size)| *block_size >= size) {
96 let (ptr, block_size) = self.free_blocks.remove(pos);
97 self.allocated_blocks.insert(ptr, size);
98
99 if block_size > size + 1024 {
101 self.free_blocks.push((ptr + size as usize, block_size - size));
102 }
103
104 Ok(ptr)
105 } else {
106 let ptr = self.allocated_blocks.len() + 1;
108 self.allocated_blocks.insert(ptr, size);
109 self.total_allocated += size;
110 self.peak_allocated = self.peak_allocated.max(self.total_allocated);
111 Ok(ptr)
112 }
113 }
114
115 pub fn deallocate(&mut self, ptr: usize) -> Result<()> {
117 if let Some(size) = self.allocated_blocks.remove(&ptr) {
118 self.free_blocks.push((ptr, size));
119 self.total_allocated -= size;
120 Ok(())
121 } else {
122 Err(TrustformersError::tensor_op_error(
123 "Invalid memory pointer",
124 "deallocate",
125 ))
126 }
127 }
128
129 pub fn stats(&self) -> (u64, u64, u64) {
131 (
132 self.total_allocated,
133 self.peak_allocated,
134 self.free_blocks.iter().map(|(_, size)| size).sum(),
135 )
136 }
137}
138
/// Execution context bound to a single device, owning its memory pool and
/// stream configuration.
#[derive(Debug)]
pub struct GpuContext {
    /// The device this context drives.
    pub device: GpuDevice,
    // Shared, lock-protected allocator for this device.
    memory_pool: Arc<Mutex<GpuMemoryPool>>,
    // Number of command streams (1 until async is enabled).
    stream_count: usize,
    // Whether asynchronous execution has been requested.
    async_enabled: bool,
}
147
148impl GpuContext {
149 pub fn new(device: GpuDevice) -> Result<Self> {
151 let memory_pool = Arc::new(Mutex::new(GpuMemoryPool::new(device.backend)));
152
153 Ok(Self {
154 device,
155 memory_pool,
156 stream_count: 1,
157 async_enabled: false,
158 })
159 }
160
161 pub fn cpu() -> Self {
163 Self {
164 device: GpuDevice::cpu(),
165 memory_pool: Arc::new(Mutex::new(GpuMemoryPool::new(GpuBackend::Cpu))),
166 stream_count: 1,
167 async_enabled: false,
168 }
169 }
170
171 pub fn enable_async(&mut self, stream_count: usize) {
173 self.async_enabled = true;
174 self.stream_count = stream_count;
175 }
176
177 pub fn allocate(&self, size: u64) -> Result<usize> {
179 let mut pool = self.memory_pool.lock().map_err(|_| {
180 TrustformersError::tensor_op_error("Failed to acquire memory pool lock", "gpu_memory")
181 })?;
182 pool.allocate(size)
183 }
184
185 pub fn deallocate(&self, ptr: usize) -> Result<()> {
187 let mut pool = self.memory_pool.lock().map_err(|_| {
188 TrustformersError::tensor_op_error("Failed to acquire memory pool lock", "gpu_memory")
189 })?;
190 pool.deallocate(ptr)
191 }
192
193 pub fn memory_stats(&self) -> Result<(u64, u64, u64)> {
195 let pool = self.memory_pool.lock().map_err(|_| {
196 TrustformersError::tensor_op_error("Failed to acquire memory pool lock", "gpu_memory")
197 })?;
198 Ok(pool.stats())
199 }
200
201 pub fn synchronize(&self) -> Result<()> {
203 match self.device.backend {
205 GpuBackend::Cuda => {
206 Ok(())
208 },
209 GpuBackend::Rocm => {
210 Ok(())
212 },
213 GpuBackend::Metal => {
214 Ok(())
216 },
217 GpuBackend::Vulkan => {
218 Ok(())
220 },
221 _ => Ok(()),
222 }
223 }
224}
225
/// Discovers devices and hands out shared `GpuContext`s, caching one context
/// per device id.
#[derive(Debug)]
pub struct GpuManager {
    // Devices found at construction time; always contains the CPU entry.
    available_devices: Vec<GpuDevice>,
    // Contexts created so far, keyed by device id.
    active_contexts: HashMap<usize, Arc<GpuContext>>,
}
232
impl GpuManager {
    /// Builds a manager, eagerly probing the machine for usable devices.
    pub fn new() -> Self {
        let available_devices = Self::detect_devices();
        Self {
            available_devices,
            active_contexts: HashMap::new(),
        }
    }

    /// Enumerates devices for every compiled-in backend.
    ///
    /// The CPU pseudo-device is always pushed first, so the returned list is
    /// never empty. Backend probes are gated by target OS / cargo features.
    fn detect_devices() -> Vec<GpuDevice> {
        let mut devices = Vec::new();

        // CPU fallback is unconditionally available.
        devices.push(GpuDevice::cpu());

        #[cfg(target_os = "macos")]
        {
            if let Ok(metal_devices) = Self::detect_metal_devices() {
                devices.extend(metal_devices);
            }
        }

        #[cfg(feature = "cuda")]
        {
            if let Ok(cuda_devices) = Self::detect_cuda_devices() {
                devices.extend(cuda_devices);
            }
        }

        #[cfg(feature = "rocm")]
        {
            if let Ok(rocm_devices) = Self::detect_rocm_devices() {
                devices.extend(rocm_devices);
            }
        }

        #[cfg(feature = "vulkan")]
        {
            if let Ok(vulkan_devices) = Self::detect_vulkan_devices() {
                devices.extend(vulkan_devices);
            }
        }

        devices
    }

    /// Reports a single Apple GPU with hard-coded specs.
    ///
    /// NOTE(review): placeholder — no actual Metal API query is performed.
    #[cfg(target_os = "macos")]
    fn detect_metal_devices() -> Result<Vec<GpuDevice>> {
        Ok(vec![GpuDevice {
            id: 1,
            name: "Apple GPU".to_string(),
            backend: GpuBackend::Metal,
            memory_total: 8 * 1024 * 1024 * 1024,  // 8 GiB (hard-coded)
            memory_free: 6 * 1024 * 1024 * 1024,   // 6 GiB (hard-coded)
            compute_capability: Some("Metal 3.0".to_string()),
            is_available: true,
        }])
    }

    /// Reports a single NVIDIA GPU with hard-coded specs.
    ///
    /// NOTE(review): placeholder — no actual CUDA runtime query is performed.
    #[cfg(feature = "cuda")]
    fn detect_cuda_devices() -> Result<Vec<GpuDevice>> {
        Ok(vec![GpuDevice {
            id: 2,
            name: "NVIDIA GPU".to_string(),
            backend: GpuBackend::Cuda,
            memory_total: 12 * 1024 * 1024 * 1024, // 12 GiB (hard-coded)
            memory_free: 10 * 1024 * 1024 * 1024,  // 10 GiB (hard-coded)
            compute_capability: Some("8.6".to_string()),
            is_available: true,
        }])
    }

    /// Reports two AMD GPUs with hard-coded specs.
    ///
    /// NOTE(review): placeholder — no actual ROCm/HIP query is performed.
    #[cfg(feature = "rocm")]
    fn detect_rocm_devices() -> Result<Vec<GpuDevice>> {
        let devices = vec![
            GpuDevice {
                id: 3,
                name: "AMD Radeon RX 6800 XT".to_string(),
                backend: GpuBackend::Rocm,
                memory_total: 16 * 1024 * 1024 * 1024, // 16 GiB (hard-coded)
                memory_free: 14 * 1024 * 1024 * 1024,  // 14 GiB (hard-coded)
                compute_capability: Some("gfx1030".to_string()),
                is_available: true,
            },
            GpuDevice {
                id: 4,
                name: "AMD Radeon RX 7900 XTX".to_string(),
                backend: GpuBackend::Rocm,
                memory_total: 24 * 1024 * 1024 * 1024, // 24 GiB (hard-coded)
                memory_free: 22 * 1024 * 1024 * 1024,  // 22 GiB (hard-coded)
                compute_capability: Some("gfx1100".to_string()),
                is_available: true,
            },
        ];

        Ok(devices)
    }

    /// Reports a single Vulkan GPU with hard-coded specs.
    ///
    /// NOTE(review): placeholder — no actual Vulkan instance query is performed.
    #[cfg(feature = "vulkan")]
    fn detect_vulkan_devices() -> Result<Vec<GpuDevice>> {
        Ok(vec![GpuDevice {
            id: 5,
            name: "Vulkan GPU".to_string(),
            backend: GpuBackend::Vulkan,
            memory_total: 8 * 1024 * 1024 * 1024, // 8 GiB (hard-coded)
            memory_free: 6 * 1024 * 1024 * 1024,  // 6 GiB (hard-coded)
            compute_capability: Some("Vulkan 1.3".to_string()),
            is_available: true,
        }])
    }

    /// All devices found at construction time (CPU always included).
    pub fn available_devices(&self) -> &[GpuDevice] {
        &self.available_devices
    }

    /// Picks the most capable available device, ranking by backend
    /// (CUDA > Metal > Vulkan > ROCm > OpenCL > WebGPU > Intel > CPU) and
    /// breaking ties by total memory.
    ///
    /// Falls back to the first device if none is marked available; the CPU
    /// entry pushed by `detect_devices` guarantees the list is non-empty, so
    /// the index cannot panic in practice.
    pub fn best_device(&self) -> &GpuDevice {
        self.available_devices
            .iter()
            .filter(|d| d.is_available)
            .max_by_key(|d| {
                let backend_score = match d.backend {
                    GpuBackend::Cuda => 100,
                    GpuBackend::Metal => 90,
                    GpuBackend::Vulkan => 80,
                    GpuBackend::Rocm => 70,
                    GpuBackend::OpenCl => 60,
                    GpuBackend::WebGpu => 50,
                    GpuBackend::Intel => 40,
                    GpuBackend::Cpu => 10,
                };
                (backend_score, d.memory_total)
            })
            .unwrap_or(&self.available_devices[0])
    }

    /// Creates a context for the device with `device_id` and caches it,
    /// replacing any previously cached context for that id.
    ///
    /// # Errors
    /// Fails when `device_id` is unknown or context creation fails.
    pub fn create_context(&mut self, device_id: usize) -> Result<Arc<GpuContext>> {
        let device =
            self.available_devices.iter().find(|d| d.id == device_id).cloned().ok_or_else(
                || {
                    TrustformersError::tensor_op_error(
                        &format!("Device {} not found", device_id),
                        "create_context",
                    )
                },
            )?;

        let context = Arc::new(GpuContext::new(device)?);
        self.active_contexts.insert(device_id, context.clone());
        Ok(context)
    }

    /// Returns the cached context for `device_id` (or for the best device
    /// when `None`), creating it on first use.
    pub fn get_or_create_context(&mut self, device_id: Option<usize>) -> Result<Arc<GpuContext>> {
        let device_id = device_id.unwrap_or_else(|| self.best_device().id);

        if let Some(context) = self.active_contexts.get(&device_id) {
            Ok(context.clone())
        } else {
            self.create_context(device_id)
        }
    }

    /// Runs device detection from scratch, without touching the global
    /// manager or its cached contexts.
    pub fn list_devices() -> Result<Vec<GpuDevice>> {
        Ok(Self::detect_devices())
    }
}
424
impl Default for GpuManager {
    /// Equivalent to `GpuManager::new()`: device detection runs eagerly.
    fn default() -> Self {
        Self::new()
    }
}
430
/// Lazily-initialized, process-wide GPU manager singleton.
static GPU_MANAGER: OnceLock<Arc<Mutex<GpuManager>>> = OnceLock::new();

/// Returns a handle to the global `GpuManager`, creating it (and running
/// device detection) on first call.
pub fn gpu_manager() -> Arc<Mutex<GpuManager>> {
    GPU_MANAGER.get_or_init(|| Arc::new(Mutex::new(GpuManager::new()))).clone()
}
438
439pub fn init_gpu(preferred_backend: Option<GpuBackend>) -> Result<Arc<GpuContext>> {
441 let manager = gpu_manager();
442 let manager_lock = manager.lock().expect("Lock poisoned");
443
444 let device_id = if let Some(backend) = preferred_backend {
445 manager_lock
446 .available_devices()
447 .iter()
448 .find(|d| d.backend == backend && d.is_available)
449 .map(|d| d.id)
450 } else {
451 Some(manager_lock.best_device().id)
452 };
453
454 let device_id = device_id.unwrap_or_else(|| manager_lock.best_device().id);
455 drop(manager_lock); let mut manager_lock = manager.lock().expect("Lock poisoned");
458 manager_lock.get_or_create_context(Some(device_id))
459}
460
/// Conversion of host-resident data to and from a GPU context.
pub trait ToGpu: Sized {
    /// Device-side representation produced by [`ToGpu::to_gpu`].
    type Output;

    /// Uploads `self` to the device owned by `context`.
    fn to_gpu(&self, context: &GpuContext) -> Result<Self::Output>;

    /// Downloads the data back into a host-side value.
    fn to_cpu(&self) -> Result<Self>;
}
471
#[cfg(test)]
mod tests {
    use super::*;

    // The CPU pseudo-device is always constructible and available.
    #[test]
    fn test_gpu_device_creation() {
        let device = GpuDevice::cpu();
        assert_eq!(device.backend, GpuBackend::Cpu);
        assert!(device.is_available);
    }

    // Distinct allocations get distinct ids and can both be freed.
    #[test]
    fn test_memory_pool_allocation() {
        let mut pool = GpuMemoryPool::new(GpuBackend::Cpu);

        let ptr1 = pool.allocate(1024).expect("operation failed in test");
        let ptr2 = pool.allocate(2048).expect("operation failed in test");

        assert_ne!(ptr1, ptr2);

        pool.deallocate(ptr1).expect("operation failed in test");
        pool.deallocate(ptr2).expect("operation failed in test");
    }

    // A fresh context starts synchronous on the requested device.
    #[test]
    fn test_gpu_context_creation() {
        let device = GpuDevice::cpu();
        let context = GpuContext::new(device).expect("operation failed in test");

        assert_eq!(context.device.backend, GpuBackend::Cpu);
        assert!(!context.async_enabled);
    }

    // Detection yields at least the CPU device; the best device is available.
    #[test]
    fn test_gpu_manager() {
        let manager = GpuManager::new();
        assert!(!manager.available_devices().is_empty());

        let best_device = manager.best_device();
        assert!(best_device.is_available);
    }

    // The #[default] variant is Metal; off macOS we only check the default is
    // one of the known variants.
    #[test]
    fn test_gpu_backend_default() {
        let backend = GpuBackend::default();

        #[cfg(target_os = "macos")]
        assert_eq!(backend, GpuBackend::Metal);

        #[cfg(not(target_os = "macos"))]
        assert!(matches!(
            backend,
            GpuBackend::Cuda
                | GpuBackend::Rocm
                | GpuBackend::Vulkan
                | GpuBackend::Metal
                | GpuBackend::WebGpu
                | GpuBackend::Cpu
        ));
    }

    // Tensor cores require CUDA with compute capability >= 7.0.
    #[test]
    fn test_tensor_cores_support() {
        let cuda_device = GpuDevice {
            id: 1,
            name: "RTX 4090".to_string(),
            backend: GpuBackend::Cuda,
            memory_total: 24 * 1024 * 1024 * 1024,
            memory_free: 20 * 1024 * 1024 * 1024,
            compute_capability: Some("8.9".to_string()),
            is_available: true,
        };

        assert!(cuda_device.supports_tensor_cores());

        let old_cuda_device = GpuDevice {
            id: 2,
            name: "GTX 1080".to_string(),
            backend: GpuBackend::Cuda,
            memory_total: 8 * 1024 * 1024 * 1024,
            memory_free: 6 * 1024 * 1024 * 1024,
            compute_capability: Some("6.1".to_string()),
            is_available: true,
        };

        assert!(!old_cuda_device.supports_tensor_cores());
    }

    // 700 of 1000 bytes in use -> utilization of 0.7.
    #[test]
    fn test_memory_utilization() {
        let device = GpuDevice {
            id: 1,
            name: "Test GPU".to_string(),
            backend: GpuBackend::Cuda,
            memory_total: 1000,
            memory_free: 300,
            compute_capability: None,
            is_available: true,
        };

        assert_eq!(device.memory_utilization(), 0.7);
    }

    // init_gpu with no preference must yield some available device.
    #[test]
    fn test_gpu_initialization() {
        let context = init_gpu(None).expect("operation failed in test");
        assert!(context.device.is_available);
    }

    // An allocation shows up in the stats and disappears after deallocation.
    #[test]
    fn test_context_memory_operations() {
        let context = GpuContext::cpu();

        let ptr = context.allocate(1024).expect("operation failed in test");
        assert!(ptr > 0);

        let stats = context.memory_stats().expect("operation failed in test");
        assert_eq!(stats.0, 1024);

        context.deallocate(ptr).expect("operation failed in test");

        let stats_after = context.memory_stats().expect("operation failed in test");
        assert_eq!(stats_after.0, 0);
    }

    // enable_async flips the flag and records the stream count.
    #[test]
    fn test_async_context() {
        let mut context = GpuContext::cpu();
        assert!(!context.async_enabled);

        context.enable_async(4);
        assert!(context.async_enabled);
        assert_eq!(context.stream_count, 4);
    }

    // Repeated requests for the same device id share one cached context.
    #[test]
    fn test_manager_context_management() {
        let mut manager = GpuManager::new();

        let context1 = manager.get_or_create_context(Some(0)).expect("operation failed in test");
        let context2 = manager.get_or_create_context(Some(0)).expect("operation failed in test");

        assert!(Arc::ptr_eq(&context1, &context2));
    }
}
619}