//! GPU backend abstraction for `trustformers_core` (gpu.rs): device
//! detection, memory pooling, and per-device context management.

1use crate::errors::{Result, TrustformersError};
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex, OnceLock};
5
6/// Supported GPU backend types
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
8pub enum GpuBackend {
9    /// NVIDIA CUDA backend
10    Cuda,
11    /// AMD ROCm backend
12    Rocm,
13    /// Apple Metal Performance Shaders
14    #[default]
15    Metal,
16    /// Vulkan compute backend
17    Vulkan,
18    /// WebGPU for browser/WASM
19    WebGpu,
20    /// OpenCL backend
21    OpenCl,
22    /// Intel oneAPI backend
23    Intel,
24    /// CPU fallback
25    Cpu,
26}
27
/// GPU device information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GpuDevice {
    /// Device identifier used for lookup in `GpuManager` (CPU is always 0).
    pub id: usize,
    /// Human-readable device name.
    pub name: String,
    /// Backend API used to drive this device.
    pub backend: GpuBackend,
    /// Total device memory in bytes (0 when unknown, e.g. for the CPU device).
    pub memory_total: u64,
    /// Currently free device memory in bytes.
    pub memory_free: u64,
    /// Backend-specific capability string (e.g. CUDA "8.6", ROCm "gfx1030").
    pub compute_capability: Option<String>,
    /// Whether the device can currently be used.
    pub is_available: bool,
}
39
40impl GpuDevice {
41    /// Create a CPU device as fallback
42    pub fn cpu() -> Self {
43        Self {
44            id: 0,
45            name: "CPU".to_string(),
46            backend: GpuBackend::Cpu,
47            memory_total: 0,
48            memory_free: 0,
49            compute_capability: None,
50            is_available: true,
51        }
52    }
53
54    /// Check if this device supports tensor cores (NVIDIA)
55    pub fn supports_tensor_cores(&self) -> bool {
56        matches!(self.backend, GpuBackend::Cuda)
57            && self.compute_capability.as_ref().map(|cc| cc.as_str() >= "7.0").unwrap_or(false)
58    }
59
60    /// Get device memory utilization ratio
61    pub fn memory_utilization(&self) -> f32 {
62        if self.memory_total == 0 {
63            0.0
64        } else {
65            1.0 - (self.memory_free as f32 / self.memory_total as f32)
66        }
67    }
68}
69
/// GPU memory pool for efficient allocation
///
/// Handles are opaque `usize` tokens, not real device pointers — the pool
/// currently only performs bookkeeping (sizes and totals) for allocations.
#[derive(Debug)]
pub struct GpuMemoryPool {
    /// Backend this pool serves (kept for future backend-specific allocators).
    #[allow(dead_code)]
    backend: GpuBackend,
    /// Live allocations: handle -> requested size in bytes.
    allocated_blocks: HashMap<usize, u64>,
    /// Recycled blocks available for reuse: `(handle, size_in_bytes)`.
    free_blocks: Vec<(usize, u64)>,
    /// Bytes currently allocated (live blocks only).
    total_allocated: u64,
    /// High-water mark of `total_allocated`.
    peak_allocated: u64,
}
80
81impl GpuMemoryPool {
82    pub fn new(backend: GpuBackend) -> Self {
83        Self {
84            backend,
85            allocated_blocks: HashMap::new(),
86            free_blocks: Vec::new(),
87            total_allocated: 0,
88            peak_allocated: 0,
89        }
90    }
91
92    /// Allocate memory block
93    pub fn allocate(&mut self, size: u64) -> Result<usize> {
94        // Find a free block of sufficient size
95        if let Some(pos) = self.free_blocks.iter().position(|(_, block_size)| *block_size >= size) {
96            let (ptr, block_size) = self.free_blocks.remove(pos);
97            self.allocated_blocks.insert(ptr, size);
98
99            // Split block if much larger than needed
100            if block_size > size + 1024 {
101                self.free_blocks.push((ptr + size as usize, block_size - size));
102            }
103
104            Ok(ptr)
105        } else {
106            // Allocate new block
107            let ptr = self.allocated_blocks.len() + 1;
108            self.allocated_blocks.insert(ptr, size);
109            self.total_allocated += size;
110            self.peak_allocated = self.peak_allocated.max(self.total_allocated);
111            Ok(ptr)
112        }
113    }
114
115    /// Deallocate memory block
116    pub fn deallocate(&mut self, ptr: usize) -> Result<()> {
117        if let Some(size) = self.allocated_blocks.remove(&ptr) {
118            self.free_blocks.push((ptr, size));
119            self.total_allocated -= size;
120            Ok(())
121        } else {
122            Err(TrustformersError::tensor_op_error(
123                "Invalid memory pointer",
124                "deallocate",
125            ))
126        }
127    }
128
129    /// Get memory statistics
130    pub fn stats(&self) -> (u64, u64, u64) {
131        (
132            self.total_allocated,
133            self.peak_allocated,
134            self.free_blocks.iter().map(|(_, size)| size).sum(),
135        )
136    }
137}
138
/// GPU context for managing device operations
#[derive(Debug)]
pub struct GpuContext {
    /// The device this context operates on.
    pub device: GpuDevice,
    /// Shared, lock-protected allocator for this context's memory.
    memory_pool: Arc<Mutex<GpuMemoryPool>>,
    /// Number of command streams requested via `enable_async` (1 by default).
    stream_count: usize,
    /// Whether asynchronous operation has been requested.
    async_enabled: bool,
}
147
148impl GpuContext {
149    /// Create a new GPU context for the given device
150    pub fn new(device: GpuDevice) -> Result<Self> {
151        let memory_pool = Arc::new(Mutex::new(GpuMemoryPool::new(device.backend)));
152
153        Ok(Self {
154            device,
155            memory_pool,
156            stream_count: 1,
157            async_enabled: false,
158        })
159    }
160
161    /// Create CPU-only context
162    pub fn cpu() -> Self {
163        Self {
164            device: GpuDevice::cpu(),
165            memory_pool: Arc::new(Mutex::new(GpuMemoryPool::new(GpuBackend::Cpu))),
166            stream_count: 1,
167            async_enabled: false,
168        }
169    }
170
171    /// Enable asynchronous operations
172    pub fn enable_async(&mut self, stream_count: usize) {
173        self.async_enabled = true;
174        self.stream_count = stream_count;
175    }
176
177    /// Allocate device memory
178    pub fn allocate(&self, size: u64) -> Result<usize> {
179        let mut pool = self.memory_pool.lock().map_err(|_| {
180            TrustformersError::tensor_op_error("Failed to acquire memory pool lock", "gpu_memory")
181        })?;
182        pool.allocate(size)
183    }
184
185    /// Deallocate device memory
186    pub fn deallocate(&self, ptr: usize) -> Result<()> {
187        let mut pool = self.memory_pool.lock().map_err(|_| {
188            TrustformersError::tensor_op_error("Failed to acquire memory pool lock", "gpu_memory")
189        })?;
190        pool.deallocate(ptr)
191    }
192
193    /// Get memory statistics
194    pub fn memory_stats(&self) -> Result<(u64, u64, u64)> {
195        let pool = self.memory_pool.lock().map_err(|_| {
196            TrustformersError::tensor_op_error("Failed to acquire memory pool lock", "gpu_memory")
197        })?;
198        Ok(pool.stats())
199    }
200
201    /// Synchronize all operations on this context
202    pub fn synchronize(&self) -> Result<()> {
203        // Platform-specific synchronization would go here
204        match self.device.backend {
205            GpuBackend::Cuda => {
206                // cudaDeviceSynchronize() equivalent
207                Ok(())
208            },
209            GpuBackend::Rocm => {
210                // hipDeviceSynchronize() equivalent
211                Ok(())
212            },
213            GpuBackend::Metal => {
214                // Metal command buffer wait until completed
215                Ok(())
216            },
217            GpuBackend::Vulkan => {
218                // vkQueueWaitIdle() equivalent
219                Ok(())
220            },
221            _ => Ok(()),
222        }
223    }
224}
225
/// GPU manager for device detection and context creation
#[derive(Debug)]
pub struct GpuManager {
    /// Devices discovered at construction time; CPU is always present.
    available_devices: Vec<GpuDevice>,
    /// Contexts created so far, keyed by device id and shared via `Arc`.
    active_contexts: HashMap<usize, Arc<GpuContext>>,
}
232
impl GpuManager {
    /// Create a manager, eagerly detecting all available devices.
    pub fn new() -> Self {
        let available_devices = Self::detect_devices();
        Self {
            available_devices,
            active_contexts: HashMap::new(),
        }
    }

    /// Detect available GPU devices.
    ///
    /// CPU is pushed unconditionally as the first entry, so the returned list
    /// is never empty; the other detectors are compiled in per-platform
    /// (`target_os = "macos"`) or per-feature (`cuda`, `rocm`, `vulkan`), and
    /// their errors are silently ignored (best-effort enumeration).
    fn detect_devices() -> Vec<GpuDevice> {
        let mut devices = Vec::new();

        // Always add CPU as fallback
        devices.push(GpuDevice::cpu());

        // Platform-specific device detection
        #[cfg(target_os = "macos")]
        {
            // Detect Metal devices
            if let Ok(metal_devices) = Self::detect_metal_devices() {
                devices.extend(metal_devices);
            }
        }

        #[cfg(feature = "cuda")]
        {
            // Detect CUDA devices
            if let Ok(cuda_devices) = Self::detect_cuda_devices() {
                devices.extend(cuda_devices);
            }
        }

        #[cfg(feature = "rocm")]
        {
            // Detect ROCm devices
            if let Ok(rocm_devices) = Self::detect_rocm_devices() {
                devices.extend(rocm_devices);
            }
        }

        #[cfg(feature = "vulkan")]
        {
            // Detect Vulkan devices
            if let Ok(vulkan_devices) = Self::detect_vulkan_devices() {
                devices.extend(vulkan_devices);
            }
        }

        devices
    }

    /// Detect Metal devices (macOS only).
    ///
    /// Stub: reports a single "Apple GPU" with placeholder memory figures; a
    /// real implementation would query the Metal framework.
    #[cfg(target_os = "macos")]
    fn detect_metal_devices() -> Result<Vec<GpuDevice>> {
        // Stub implementation - would use Metal framework
        Ok(vec![GpuDevice {
            id: 1,
            name: "Apple GPU".to_string(),
            backend: GpuBackend::Metal,
            memory_total: 8 * 1024 * 1024 * 1024, // 8GB placeholder
            memory_free: 6 * 1024 * 1024 * 1024,  // 6GB placeholder
            compute_capability: Some("Metal 3.0".to_string()),
            is_available: true,
        }])
    }

    /// Detect CUDA devices (behind the `cuda` feature).
    ///
    /// Stub: reports one placeholder device; a real implementation would use
    /// the CUDA runtime API (cudaGetDeviceCount / cudaGetDeviceProperties).
    #[cfg(feature = "cuda")]
    fn detect_cuda_devices() -> Result<Vec<GpuDevice>> {
        // Stub implementation - would use CUDA runtime API
        Ok(vec![GpuDevice {
            id: 2,
            name: "NVIDIA GPU".to_string(),
            backend: GpuBackend::Cuda,
            memory_total: 12 * 1024 * 1024 * 1024, // 12GB placeholder
            memory_free: 10 * 1024 * 1024 * 1024,  // 10GB placeholder
            compute_capability: Some("8.6".to_string()),
            is_available: true,
        }])
    }

    /// Detect ROCm devices (behind the `rocm` feature).
    ///
    /// Stub: returns two hard-coded example devices; the comments below list
    /// the HIP calls a real implementation would make.
    #[cfg(feature = "rocm")]
    fn detect_rocm_devices() -> Result<Vec<GpuDevice>> {
        // ROCm device detection using ROCm Runtime API
        // This would typically use hipGetDeviceCount() and hipGetDeviceProperties()

        // Simulate ROCm device enumeration
        // In a real implementation, this would call:
        // - hipInit() to initialize ROCm
        // - hipGetDeviceCount() to get number of devices
        // - hipGetDeviceProperties() for each device

        let devices = vec![
            // Example for RX 6800 XT
            GpuDevice {
                id: 3,
                name: "AMD Radeon RX 6800 XT".to_string(),
                backend: GpuBackend::Rocm,
                memory_total: 16 * 1024 * 1024 * 1024, // 16GB
                memory_free: 14 * 1024 * 1024 * 1024,  // 14GB
                compute_capability: Some("gfx1030".to_string()), // RDNA 2
                is_available: true,
            },
            // Example for RX 7900 XTX
            GpuDevice {
                id: 4,
                name: "AMD Radeon RX 7900 XTX".to_string(),
                backend: GpuBackend::Rocm,
                memory_total: 24 * 1024 * 1024 * 1024, // 24GB
                memory_free: 22 * 1024 * 1024 * 1024,  // 22GB
                compute_capability: Some("gfx1100".to_string()), // RDNA 3
                is_available: true,
            },
        ];

        Ok(devices)
    }

    /// Detect Vulkan devices (behind the `vulkan` feature).
    ///
    /// Stub: reports one placeholder device; a real implementation would
    /// enumerate physical devices through the Vulkan API.
    #[cfg(feature = "vulkan")]
    fn detect_vulkan_devices() -> Result<Vec<GpuDevice>> {
        // Stub implementation - would use Vulkan API
        Ok(vec![GpuDevice {
            id: 5,
            name: "Vulkan GPU".to_string(),
            backend: GpuBackend::Vulkan,
            memory_total: 8 * 1024 * 1024 * 1024, // 8GB placeholder
            memory_free: 6 * 1024 * 1024 * 1024,  // 6GB placeholder
            compute_capability: Some("Vulkan 1.3".to_string()),
            is_available: true,
        }])
    }

    /// Get all available devices.
    pub fn available_devices(&self) -> &[GpuDevice] {
        &self.available_devices
    }

    /// Get the best available device.
    ///
    /// Ranks by a fixed backend preference first, then by total memory.
    /// Falls back to the first device (always CPU, added by `detect_devices`)
    /// if no device is marked available.
    pub fn best_device(&self) -> &GpuDevice {
        // Prefer GPU over CPU, and newer/more capable GPUs
        self.available_devices
            .iter()
            .filter(|d| d.is_available)
            .max_by_key(|d| {
                let backend_score = match d.backend {
                    GpuBackend::Cuda => 100,
                    GpuBackend::Metal => 90,
                    GpuBackend::Vulkan => 80,
                    GpuBackend::Rocm => 70,
                    GpuBackend::OpenCl => 60,
                    GpuBackend::WebGpu => 50,
                    GpuBackend::Intel => 40,
                    GpuBackend::Cpu => 10,
                };
                (backend_score, d.memory_total)
            })
            .unwrap_or(&self.available_devices[0])
    }

    /// Create a context for the specified device id and register it.
    ///
    /// # Errors
    /// Returns an error if no detected device has the given id.
    pub fn create_context(&mut self, device_id: usize) -> Result<Arc<GpuContext>> {
        let device =
            self.available_devices.iter().find(|d| d.id == device_id).cloned().ok_or_else(
                || {
                    TrustformersError::tensor_op_error(
                        &format!("Device {} not found", device_id),
                        "create_context",
                    )
                },
            )?;

        let context = Arc::new(GpuContext::new(device)?);
        self.active_contexts.insert(device_id, context.clone());
        Ok(context)
    }

    /// Get an existing context for `device_id` or create a new one.
    ///
    /// With `None`, the best available device is used.
    pub fn get_or_create_context(&mut self, device_id: Option<usize>) -> Result<Arc<GpuContext>> {
        let device_id = device_id.unwrap_or_else(|| self.best_device().id);

        if let Some(context) = self.active_contexts.get(&device_id) {
            Ok(context.clone())
        } else {
            self.create_context(device_id)
        }
    }

    /// List available GPU devices (backward compatibility).
    ///
    /// NOTE(review): this re-runs detection on every call rather than reading
    /// a cached manager instance.
    pub fn list_devices() -> Result<Vec<GpuDevice>> {
        Ok(Self::detect_devices())
    }
}
424
425impl Default for GpuManager {
426    fn default() -> Self {
427        Self::new()
428    }
429}
430
431/// Global GPU manager instance
432static GPU_MANAGER: OnceLock<Arc<Mutex<GpuManager>>> = OnceLock::new();
433
434/// Get the global GPU manager
435pub fn gpu_manager() -> Arc<Mutex<GpuManager>> {
436    GPU_MANAGER.get_or_init(|| Arc::new(Mutex::new(GpuManager::new()))).clone()
437}
438
439/// Initialize GPU subsystem with optional device preference
440pub fn init_gpu(preferred_backend: Option<GpuBackend>) -> Result<Arc<GpuContext>> {
441    let manager = gpu_manager();
442    let manager_lock = manager.lock().expect("Lock poisoned");
443
444    let device_id = if let Some(backend) = preferred_backend {
445        manager_lock
446            .available_devices()
447            .iter()
448            .find(|d| d.backend == backend && d.is_available)
449            .map(|d| d.id)
450    } else {
451        Some(manager_lock.best_device().id)
452    };
453
454    let device_id = device_id.unwrap_or_else(|| manager_lock.best_device().id);
455    drop(manager_lock); // Release the lock before calling get_or_create_context
456
457    let mut manager_lock = manager.lock().expect("Lock poisoned");
458    manager_lock.get_or_create_context(Some(device_id))
459}
460
/// Trait for types that can be moved to GPU
pub trait ToGpu: Sized {
    /// The GPU-resident counterpart produced by [`ToGpu::to_gpu`].
    type Output;

    /// Move this object to the specified GPU context.
    fn to_gpu(&self, context: &GpuContext) -> Result<Self::Output>;

    /// Move this object back to CPU.
    fn to_cpu(&self) -> Result<Self>;
}
471
#[cfg(test)]
mod tests {
    use super::*;

    // The CPU fallback device must always report as available.
    #[test]
    fn test_gpu_device_creation() {
        let device = GpuDevice::cpu();
        assert_eq!(device.backend, GpuBackend::Cpu);
        assert!(device.is_available);
    }

    // Two live allocations must receive distinct handles; both must free cleanly.
    #[test]
    fn test_memory_pool_allocation() {
        let mut pool = GpuMemoryPool::new(GpuBackend::Cpu);

        let ptr1 = pool.allocate(1024).expect("operation failed in test");
        let ptr2 = pool.allocate(2048).expect("operation failed in test");

        assert_ne!(ptr1, ptr2);

        pool.deallocate(ptr1).expect("operation failed in test");
        pool.deallocate(ptr2).expect("operation failed in test");
    }

    // A freshly created context is synchronous by default.
    #[test]
    fn test_gpu_context_creation() {
        let device = GpuDevice::cpu();
        let context = GpuContext::new(device).expect("operation failed in test");

        assert_eq!(context.device.backend, GpuBackend::Cpu);
        assert!(!context.async_enabled);
    }

    // Detection always yields at least the CPU device, and best_device()
    // must pick something available.
    #[test]
    fn test_gpu_manager() {
        let manager = GpuManager::new();
        assert!(!manager.available_devices().is_empty());

        let best_device = manager.best_device();
        assert!(best_device.is_available);
    }

    // The default backend is platform-dependent; only macOS pins Metal.
    #[test]
    fn test_gpu_backend_default() {
        let backend = GpuBackend::default();

        #[cfg(target_os = "macos")]
        assert_eq!(backend, GpuBackend::Metal);

        // On non-macOS platforms, any valid backend is acceptable
        // The actual backend depends on available hardware
        #[cfg(not(target_os = "macos"))]
        assert!(matches!(
            backend,
            GpuBackend::Cuda
                | GpuBackend::Rocm
                | GpuBackend::Vulkan
                | GpuBackend::Metal
                | GpuBackend::WebGpu
                | GpuBackend::Cpu
        ));
    }

    // Tensor-core support requires CUDA with compute capability >= 7.0:
    // CC 8.9 (Ada) qualifies, CC 6.1 (Pascal) does not.
    #[test]
    fn test_tensor_cores_support() {
        let cuda_device = GpuDevice {
            id: 1,
            name: "RTX 4090".to_string(),
            backend: GpuBackend::Cuda,
            memory_total: 24 * 1024 * 1024 * 1024,
            memory_free: 20 * 1024 * 1024 * 1024,
            compute_capability: Some("8.9".to_string()),
            is_available: true,
        };

        assert!(cuda_device.supports_tensor_cores());

        let old_cuda_device = GpuDevice {
            id: 2,
            name: "GTX 1080".to_string(),
            backend: GpuBackend::Cuda,
            memory_total: 8 * 1024 * 1024 * 1024,
            memory_free: 6 * 1024 * 1024 * 1024,
            compute_capability: Some("6.1".to_string()),
            is_available: true,
        };

        assert!(!old_cuda_device.supports_tensor_cores());
    }

    // 300 free of 1000 total -> 70% utilized.
    #[test]
    fn test_memory_utilization() {
        let device = GpuDevice {
            id: 1,
            name: "Test GPU".to_string(),
            backend: GpuBackend::Cuda,
            memory_total: 1000,
            memory_free: 300,
            compute_capability: None,
            is_available: true,
        };

        assert_eq!(device.memory_utilization(), 0.7);
    }

    // init_gpu with no preference must always succeed (CPU fallback).
    #[test]
    fn test_gpu_initialization() {
        let context = init_gpu(None).expect("operation failed in test");
        assert!(context.device.is_available);
    }

    // Allocation must be reflected in stats and fully reversed by deallocate.
    #[test]
    fn test_context_memory_operations() {
        let context = GpuContext::cpu();

        let ptr = context.allocate(1024).expect("operation failed in test");
        assert!(ptr > 0);

        let stats = context.memory_stats().expect("operation failed in test");
        assert_eq!(stats.0, 1024); // total allocated

        context.deallocate(ptr).expect("operation failed in test");

        let stats_after = context.memory_stats().expect("operation failed in test");
        assert_eq!(stats_after.0, 0); // total allocated after free
    }

    // enable_async flips the flag and records the stream count.
    #[test]
    fn test_async_context() {
        let mut context = GpuContext::cpu();
        assert!(!context.async_enabled);

        context.enable_async(4);
        assert!(context.async_enabled);
        assert_eq!(context.stream_count, 4);
    }

    // Repeated requests for the same device must return the same Arc.
    #[test]
    fn test_manager_context_management() {
        let mut manager = GpuManager::new();

        let context1 = manager.get_or_create_context(Some(0)).expect("operation failed in test");
        let context2 = manager.get_or_create_context(Some(0)).expect("operation failed in test");

        // Should return the same context for the same device
        assert!(Arc::ptr_eq(&context1, &context2));
    }
}