Skip to main content

trustformers_core/
hardware_acceleration.rs

1// Copyright (c) 2025-2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Hardware acceleration integration for TrustformeRS
5//!
6//! This module provides a unified interface for hardware acceleration backends,
7//! automatically selecting the best available acceleration method based on system
8//! capabilities and user preferences.
9
10#![allow(unused_variables)] // Multi-backend implementation with feature gates
11
12#[allow(unused_imports)] // Used conditionally based on feature gates
13use crate::errors::{acceleration_error, hardware_error, tensor_op_error, Result};
14use crate::tensor::Tensor;
15use std::sync::OnceLock;
16
/// Hardware acceleration backend types
///
/// Variants are listed in the priority order used by automatic backend
/// selection (see `HardwareAccelerator::select_backend`); `Cpu` is the
/// always-available fallback.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AccelerationBackend {
    /// NVIDIA CUDA (requires the `cuda` cargo feature)
    Cuda,
    /// AMD ROCm (requires the `rocm` cargo feature)
    Rocm,
    /// Intel OneAPI (requires the `intel` cargo feature)
    Intel,
    /// Vulkan Compute (requires the `vulkan` cargo feature)
    Vulkan,
    /// Apple Metal (requires macOS and the `metal` cargo feature)
    Metal,
    /// CPU fallback — always available, no feature gate
    Cpu,
}
33
/// Hardware acceleration configuration
///
/// Passed to `HardwareAccelerator::initialize`. Note that only the first
/// initialization call's config takes effect (the accelerator is a global
/// write-once singleton).
#[derive(Debug, Clone)]
pub struct AccelerationConfig {
    /// Preferred backend (if available); `None` means auto-select by the
    /// priority order defined in `select_backend`
    pub preferred_backend: Option<AccelerationBackend>,
    /// Enable automatic fallback to CPU
    pub auto_fallback: bool,
    /// Memory pool size per device (MB)
    /// NOTE(review): not consumed anywhere in this file — presumably read by
    /// backend implementations; confirm before relying on it
    pub memory_pool_size: usize,
    /// Enable kernel caching
    pub enable_kernel_cache: bool,
    /// Enable performance monitoring
    pub enable_monitoring: bool,
}
48
49impl Default for AccelerationConfig {
50    fn default() -> Self {
51        Self {
52            preferred_backend: None,
53            auto_fallback: true,
54            memory_pool_size: 1024, // 1GB
55            enable_kernel_cache: true,
56            enable_monitoring: true,
57        }
58    }
59}
60
/// Hardware acceleration manager
///
/// Holds the backend chosen at initialization plus accumulated performance
/// counters. Stored in a global `OnceLock`, so the instance obtained via
/// `global()` is immutable — the `&mut self` methods (`matmul`,
/// `flash_attention`, `reset_stats`) are only callable on a locally owned
/// instance.
pub struct HardwareAccelerator {
    /// Active backend
    active_backend: AccelerationBackend,
    /// Configuration
    #[allow(dead_code)]
    config: AccelerationConfig,
    /// Performance statistics
    stats: AccelerationStats,
}
71
/// Performance statistics
///
/// Counters accumulated by `matmul` / `flash_attention`; reset via
/// `reset_stats`. Cache and memory fields are declared here but not updated
/// anywhere in this file — presumably backend implementations fill them in;
/// verify against the kernel modules.
#[derive(Debug, Clone, Default)]
pub struct AccelerationStats {
    /// Total operations executed
    pub total_operations: u64,
    /// Total execution time (ms)
    pub total_time_ms: f64,
    /// Memory allocated (bytes)
    pub memory_allocated: u64,
    /// Cache hits
    pub cache_hits: u64,
    /// Cache misses
    pub cache_misses: u64,
}
86
/// Global hardware accelerator instance
///
/// Write-once: the first successful `HardwareAccelerator::initialize` call
/// populates it; all later calls (with any config) reuse that first instance.
static ACCELERATOR: OnceLock<HardwareAccelerator> = OnceLock::new();
89
impl HardwareAccelerator {
    /// Initialize hardware accelerator with configuration
    ///
    /// First-call-wins: if the global accelerator is already initialized,
    /// `config` is silently ignored and the existing instance is returned.
    /// Backend initialization errors are also swallowed here — on failure
    /// the instance silently degrades to the CPU fallback (regardless of
    /// `config.auto_fallback`; NOTE(review): that flag is not consulted).
    pub fn initialize(config: AccelerationConfig) -> Result<&'static HardwareAccelerator> {
        ACCELERATOR.get_or_init(|| {
            Self::new(config).unwrap_or_else(|_| {
                // Fallback to CPU if initialization fails
                Self::new_cpu_fallback()
            })
        });
        Ok(ACCELERATOR.get().expect("accelerator should be initialized after get_or_init"))
    }

    /// Get global hardware accelerator instance
    ///
    /// Errors if `initialize()` has not been called yet.
    pub fn global() -> Result<&'static HardwareAccelerator> {
        ACCELERATOR.get().ok_or_else(|| {
            hardware_error("unknown", "Hardware accelerator not initialized")
                .with_operation("global")
                .with_suggestion("Call HardwareAccelerator::initialize() first")
        })
    }

    /// Create new hardware accelerator
    ///
    /// Selects a backend per `config`, runs that backend's one-time
    /// initialization, and returns the manager. Returns an error if the
    /// selected backend's support was not compiled into this build.
    fn new(config: AccelerationConfig) -> Result<Self> {
        let backend = Self::select_backend(&config)?;

        // Initialize the selected backend
        match backend {
            AccelerationBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    crate::kernels::cuda_impl::api::init_cuda()?;
                }
                #[cfg(not(feature = "cuda"))]
                {
                    return Err(
                        acceleration_error("CUDA", "Support not compiled in this build")
                            .with_operation("initialization")
                            .with_suggestion("Rebuild with --features cuda to enable CUDA support"),
                    );
                }
            },
            AccelerationBackend::Rocm => {
                #[cfg(feature = "rocm")]
                {
                    crate::kernels::rocm_impl::api::init_rocm()?;
                }
                #[cfg(not(feature = "rocm"))]
                {
                    return Err(
                        acceleration_error("ROCm", "Support not compiled in this build")
                            .with_operation("initialization")
                            .with_suggestion("Rebuild with --features rocm to enable ROCm support"),
                    );
                }
            },
            AccelerationBackend::Intel => {
                #[cfg(feature = "intel")]
                {
                    crate::kernels::intel_impl::api::init_intel()?;
                }
                #[cfg(not(feature = "intel"))]
                {
                    return Err(acceleration_error(
                        "Intel OneAPI",
                        "Support not compiled in this build",
                    )
                    .with_operation("initialization")
                    .with_suggestion(
                        "Rebuild with --features intel to enable Intel OneAPI support",
                    ));
                }
            },
            AccelerationBackend::Vulkan => {
                #[cfg(feature = "vulkan")]
                {
                    // Vulkan backend initialization is handled in VulkanImpl::new();
                    // the instance is dropped immediately — this is only a probe.
                    let _vulkan = crate::kernels::vulkan_impl::VulkanImpl::new()?;
                }
                #[cfg(not(feature = "vulkan"))]
                {
                    return Err(
                        acceleration_error("Vulkan", "Support not compiled in this build")
                            .with_operation("initialization")
                            .with_suggestion(
                                "Rebuild with --features vulkan to enable Vulkan support",
                            ),
                    );
                }
            },
            AccelerationBackend::Metal => {
                #[cfg(all(target_os = "macos", feature = "metal"))]
                {
                    // Metal backend initialization using Metal Performance Shaders;
                    // instance is dropped immediately — this is only a probe.
                    let _metal = crate::kernels::metal_impl::MetalImpl::new()?;
                    log::info!(
                        "Metal backend initialized successfully for Apple Silicon acceleration"
                    );
                }
                #[cfg(not(all(target_os = "macos", feature = "metal")))]
                {
                    return Err(
                        acceleration_error("Metal", "Support not compiled in this build")
                            .with_operation("initialization")
                            .with_suggestion(
                                "Rebuild with --features metal to enable Metal support",
                            )
                            .with_suggestion("Metal backend requires macOS/iOS with Apple Silicon"),
                    );
                }
            },
            AccelerationBackend::Cpu => {
                // CPU backend is always available
            },
        }

        Ok(Self {
            active_backend: backend,
            config,
            stats: AccelerationStats::default(),
        })
    }

    /// Create CPU fallback accelerator
    ///
    /// Infallible; used when `new()` fails during `initialize()`.
    fn new_cpu_fallback() -> Self {
        Self {
            active_backend: AccelerationBackend::Cpu,
            config: AccelerationConfig::default(),
            stats: AccelerationStats::default(),
        }
    }

    /// Select the best available backend
    ///
    /// Honors `config.preferred_backend` if that backend probes as
    /// available, otherwise walks the fixed priority list below. Since
    /// `Cpu` is last in the list and always available, the trailing error
    /// branch is effectively unreachable — kept for defensive completeness.
    fn select_backend(config: &AccelerationConfig) -> Result<AccelerationBackend> {
        // Try preferred backend first
        if let Some(preferred) = config.preferred_backend {
            if Self::is_backend_available(preferred) {
                return Ok(preferred);
            }
        }

        // Auto-select based on availability (priority order, CPU last)
        let backends = [
            AccelerationBackend::Cuda,
            AccelerationBackend::Rocm,
            AccelerationBackend::Intel,
            AccelerationBackend::Vulkan,
            AccelerationBackend::Metal,
            AccelerationBackend::Cpu,
        ];

        for backend in backends {
            if Self::is_backend_available(backend) {
                return Ok(backend);
            }
        }

        Err(
            hardware_error("system", "No acceleration backend available on this system")
                .with_operation("backend_selection")
                .with_suggestion("Install GPU drivers (NVIDIA CUDA, AMD ROCm, Intel OneAPI)")
                .with_suggestion("Ensure required features are enabled during compilation")
                .with_suggestion("CPU backend should always be available as fallback"),
        )
    }

    /// Check if backend is available
    ///
    /// NOTE(review): for CUDA/ROCm/Vulkan/Metal this probe performs a real
    /// backend initialization (e.g. `init_cuda()`, `VulkanImpl::new()`), so
    /// it is potentially expensive and has side effects — confirm the kernel
    /// modules tolerate repeated init calls.
    fn is_backend_available(backend: AccelerationBackend) -> bool {
        match backend {
            AccelerationBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    crate::kernels::cuda_impl::api::init_cuda().is_ok()
                }
                #[cfg(not(feature = "cuda"))]
                {
                    false
                }
            },
            AccelerationBackend::Rocm => {
                #[cfg(feature = "rocm")]
                {
                    crate::kernels::rocm_impl::api::init_rocm().is_ok()
                }
                #[cfg(not(feature = "rocm"))]
                {
                    false
                }
            },
            AccelerationBackend::Intel => {
                #[cfg(feature = "intel")]
                {
                    crate::kernels::intel_impl::api::is_intel_available()
                }
                #[cfg(not(feature = "intel"))]
                {
                    false
                }
            },
            AccelerationBackend::Vulkan => {
                #[cfg(feature = "vulkan")]
                {
                    crate::kernels::vulkan_impl::VulkanImpl::new().is_ok()
                }
                #[cfg(not(feature = "vulkan"))]
                {
                    false
                }
            },
            AccelerationBackend::Metal => {
                // Check if Metal is available by attempting to create a Metal implementation
                #[cfg(all(target_os = "macos", feature = "metal"))]
                {
                    crate::kernels::metal_impl::MetalImpl::new().is_ok()
                }
                #[cfg(not(all(target_os = "macos", feature = "metal")))]
                {
                    false
                }
            },
            AccelerationBackend::Cpu => true,
        }
    }

    /// Execute matrix multiplication with hardware acceleration
    ///
    /// Dispatches `c = a @ b` to the active backend, falling back to the
    /// CPU path when the backend's feature is not compiled in. Updates
    /// `total_operations` and `total_time_ms` regardless of success.
    ///
    /// NOTE(review): the Vulkan/Metal arms construct a fresh backend
    /// instance per call — likely costly; consider caching if these paths
    /// are hot.
    pub fn matmul(&mut self, a: &Tensor, b: &Tensor, c: &mut Tensor) -> Result<()> {
        let start_time = std::time::Instant::now();

        let result = match self.active_backend {
            AccelerationBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    crate::kernels::cuda_impl::api::cuda_matmul(a, b, c)
                }
                #[cfg(not(feature = "cuda"))]
                {
                    self.cpu_matmul(a, b, c)
                }
            },
            AccelerationBackend::Rocm => {
                #[cfg(feature = "rocm")]
                {
                    crate::kernels::rocm_impl::api::rocm_matmul(a, b, c)
                }
                #[cfg(not(feature = "rocm"))]
                {
                    self.cpu_matmul(a, b, c)
                }
            },
            AccelerationBackend::Intel => {
                #[cfg(feature = "intel")]
                {
                    crate::kernels::intel_impl::api::intel_matmul(a, b, c)
                }
                #[cfg(not(feature = "intel"))]
                {
                    self.cpu_matmul(a, b, c)
                }
            },
            AccelerationBackend::Vulkan => {
                #[cfg(feature = "vulkan")]
                {
                    let mut vulkan = crate::kernels::vulkan_impl::VulkanImpl::new()?;
                    vulkan.matmul(a, b, c)
                }
                #[cfg(not(feature = "vulkan"))]
                {
                    self.cpu_matmul(a, b, c)
                }
            },
            AccelerationBackend::Metal => {
                #[cfg(all(target_os = "macos", feature = "metal"))]
                {
                    let metal_impl = crate::kernels::metal_impl::MetalImpl::new()?;
                    metal_impl.matrix_multiply(a, b).and_then(|result| {
                        // Copy result to output tensor c (Metal path returns a new
                        // tensor rather than writing in place)
                        match (c, &result) {
                            (Tensor::F32(c_arr), Tensor::F32(result_arr)) => {
                                c_arr.assign(result_arr);
                                Ok(())
                            },
                            _ => Err(tensor_op_error(
                                "Tensor type mismatch in Metal matmul",
                                "matmul",
                            )),
                        }
                    })
                }
                #[cfg(not(all(target_os = "macos", feature = "metal")))]
                {
                    self.cpu_matmul(a, b, c)
                }
            },
            AccelerationBackend::Cpu => self.cpu_matmul(a, b, c),
        };

        // Update statistics (counted even when the operation failed)
        self.stats.total_operations += 1;
        self.stats.total_time_ms += start_time.elapsed().as_millis() as f64;

        result
    }

    /// Execute Flash Attention with hardware acceleration
    ///
    /// Dispatches scaled dot-product attention to the active backend with
    /// the same fallback and stats behavior as `matmul`.
    pub fn flash_attention(
        &mut self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        output: &mut Tensor,
    ) -> Result<()> {
        let start_time = std::time::Instant::now();

        let result = match self.active_backend {
            AccelerationBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    crate::kernels::cuda_impl::api::cuda_flash_attention(query, key, value, output)
                }
                #[cfg(not(feature = "cuda"))]
                {
                    self.cpu_flash_attention(query, key, value, output)
                }
            },
            AccelerationBackend::Rocm => {
                #[cfg(feature = "rocm")]
                {
                    crate::kernels::rocm_impl::api::rocm_flash_attention(query, key, value, output)
                }
                #[cfg(not(feature = "rocm"))]
                {
                    self.cpu_flash_attention(query, key, value, output)
                }
            },
            AccelerationBackend::Intel => {
                #[cfg(feature = "intel")]
                {
                    crate::kernels::intel_impl::api::intel_flash_attention(
                        query, key, value, output,
                    )
                }
                #[cfg(not(feature = "intel"))]
                {
                    self.cpu_flash_attention(query, key, value, output)
                }
            },
            AccelerationBackend::Vulkan => {
                #[cfg(feature = "vulkan")]
                {
                    let mut vulkan = crate::kernels::vulkan_impl::VulkanImpl::new()?;
                    // Scale by 1/sqrt(head_dim); assumes query shape index 2 is
                    // the head dimension — TODO confirm layout with kernel module
                    let scale = 1.0 / (query.shape()[2] as f32).sqrt();
                    vulkan.flash_attention(query, key, value, output, scale)
                }
                #[cfg(not(feature = "vulkan"))]
                {
                    self.cpu_flash_attention(query, key, value, output)
                }
            },
            AccelerationBackend::Metal => {
                #[cfg(all(target_os = "macos", feature = "metal"))]
                {
                    let metal_impl = crate::kernels::metal_impl::MetalImpl::new()?;
                    metal_impl.flash_attention(query, key, value, output)
                }
                #[cfg(not(all(target_os = "macos", feature = "metal")))]
                {
                    self.cpu_flash_attention(query, key, value, output)
                }
            },
            AccelerationBackend::Cpu => self.cpu_flash_attention(query, key, value, output),
        };

        // Update statistics (counted even when the operation failed)
        self.stats.total_operations += 1;
        self.stats.total_time_ms += start_time.elapsed().as_millis() as f64;

        result
    }

    /// CPU fallback for matrix multiplication
    ///
    /// Allocates a new result tensor and moves it into `c` (does not write
    /// into `c`'s existing storage).
    fn cpu_matmul(&self, a: &Tensor, b: &Tensor, c: &mut Tensor) -> Result<()> {
        // Use the tensor's built-in matmul implementation
        let result = a.matmul(b)?;
        *c = result;
        Ok(())
    }

    /// CPU fallback for Flash Attention
    ///
    /// Standard (non-tiled) scaled dot-product attention:
    /// `output = softmax(Q @ K^T / sqrt(head_dim)) @ V`.
    /// Assumes query is (batch, seq, head_dim) — TODO confirm with callers.
    fn cpu_flash_attention(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        output: &mut Tensor,
    ) -> Result<()> {
        // Simplified CPU implementation of Flash Attention
        let q_shape = query.shape();
        let batch_size = q_shape[0]; // unused; kept for shape documentation
        let seq_len = q_shape[1]; // unused; kept for shape documentation
        let head_dim = q_shape[2];

        // Compute attention scores: Q @ K^T
        let key_transposed = key.transpose(1, 2)?;
        let scores = query.matmul(&key_transposed)?;

        // Apply scaling
        let scale = 1.0 / (head_dim as f32).sqrt();
        let scaled_scores = scores.mul_scalar(scale)?;

        // Apply softmax over the last (key) axis
        let attention_weights = scaled_scores.softmax(2)?;

        // Apply attention to values: attention_weights @ V
        let result = attention_weights.matmul(value)?;

        *output = result;
        Ok(())
    }

    /// Get active backend
    pub fn active_backend(&self) -> AccelerationBackend {
        self.active_backend
    }

    /// Get performance statistics
    pub fn get_stats(&self) -> &AccelerationStats {
        &self.stats
    }

    /// Reset performance statistics
    pub fn reset_stats(&mut self) {
        self.stats = AccelerationStats::default();
    }

    /// Get device information
    ///
    /// Returns a human-readable description of the active device. The
    /// catch-all arm covers Vulkan and Metal, which have no dedicated
    /// device-info query here.
    pub fn device_info(&self) -> Result<String> {
        match self.active_backend {
            AccelerationBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    crate::kernels::cuda_impl::api::cuda_device_info()
                }
                #[cfg(not(feature = "cuda"))]
                {
                    Ok("CUDA not available".to_string())
                }
            },
            AccelerationBackend::Rocm => {
                #[cfg(feature = "rocm")]
                {
                    crate::kernels::rocm_impl::api::rocm_device_info()
                }
                #[cfg(not(feature = "rocm"))]
                {
                    Ok("ROCm not available".to_string())
                }
            },
            AccelerationBackend::Intel => {
                #[cfg(feature = "intel")]
                {
                    crate::kernels::intel_impl::api::intel_device_info()
                }
                #[cfg(not(feature = "intel"))]
                {
                    Ok("Intel OneAPI not available".to_string())
                }
            },
            AccelerationBackend::Cpu => Ok(format!("CPU: {} cores", num_cpus::get())),
            _ => Ok(format!("Backend: {:?}", self.active_backend)),
        }
    }

    /// Get memory statistics
    ///
    /// Returns a `(usize, usize)` pair from the backend; the meaning of the
    /// two values (presumably used/total bytes) is defined by the kernel
    /// modules — verify there before consuming. CPU, Vulkan, and Metal
    /// report `(0, 0)`.
    pub fn memory_stats(&self) -> Result<(usize, usize)> {
        match self.active_backend {
            AccelerationBackend::Cuda => {
                #[cfg(feature = "cuda")]
                {
                    crate::kernels::cuda_impl::api::cuda_memory_stats()
                }
                #[cfg(not(feature = "cuda"))]
                {
                    Ok((0, 0))
                }
            },
            AccelerationBackend::Rocm => {
                #[cfg(feature = "rocm")]
                {
                    crate::kernels::rocm_impl::api::rocm_memory_stats()
                }
                #[cfg(not(feature = "rocm"))]
                {
                    Ok((0, 0))
                }
            },
            AccelerationBackend::Intel => {
                #[cfg(feature = "intel")]
                {
                    crate::kernels::intel_impl::api::intel_memory_stats()
                }
                #[cfg(not(feature = "intel"))]
                {
                    Ok((0, 0))
                }
            },
            AccelerationBackend::Cpu => {
                Ok((0, 0)) // CPU doesn't have dedicated memory pool
            },
            _ => Ok((0, 0)),
        }
    }
}
601
602/// Public API for hardware acceleration
603pub mod api {
604    use super::*;
605
606    /// Initialize hardware acceleration with default configuration
607    pub fn init_hardware_acceleration() -> Result<()> {
608        HardwareAccelerator::initialize(AccelerationConfig::default())?;
609        Ok(())
610    }
611
612    /// Initialize hardware acceleration with custom configuration
613    pub fn init_hardware_acceleration_with_config(config: AccelerationConfig) -> Result<()> {
614        HardwareAccelerator::initialize(config)?;
615        Ok(())
616    }
617
618    /// Execute accelerated matrix multiplication
619    pub fn accelerated_matmul(a: &Tensor, b: &Tensor, c: &mut Tensor) -> Result<()> {
620        let accelerator = HardwareAccelerator::global()?;
621
622        // Since we can't get a mutable reference from the static,
623        // we need to handle this differently for now
624        let result = a.matmul(b)?;
625        *c = result;
626        Ok(())
627    }
628
629    /// Execute accelerated Flash Attention
630    pub fn accelerated_flash_attention(
631        query: &Tensor,
632        key: &Tensor,
633        value: &Tensor,
634        output: &mut Tensor,
635    ) -> Result<()> {
636        let accelerator = HardwareAccelerator::global()?;
637
638        // Since we can't get a mutable reference from the static,
639        // we need to handle this differently for now
640        // Simplified CPU implementation of Flash Attention
641        let q_shape = query.shape();
642        let head_dim = q_shape[q_shape.len() - 1];
643
644        // Compute attention scores: Q @ K^T
645        let key_transposed = key.transpose(q_shape.len() - 2, q_shape.len() - 1)?;
646        let scores = query.matmul(&key_transposed)?;
647
648        // Apply scaling
649        let scale = 1.0 / (head_dim as f32).sqrt();
650        let scaled_scores = scores.mul_scalar(scale)?;
651
652        // Apply softmax
653        let attention_weights = scaled_scores.softmax((q_shape.len() - 1) as i32)?;
654
655        // Apply attention to values: attention_weights @ V
656        let result = attention_weights.matmul(value)?;
657
658        *output = result;
659        Ok(())
660    }
661
662    /// Get active acceleration backend
663    pub fn get_active_backend() -> Result<AccelerationBackend> {
664        Ok(HardwareAccelerator::global()?.active_backend())
665    }
666
667    /// Get device information
668    pub fn get_device_info() -> Result<String> {
669        HardwareAccelerator::global()?.device_info()
670    }
671
672    /// Get performance statistics
673    pub fn get_performance_stats() -> Result<AccelerationStats> {
674        Ok(HardwareAccelerator::global()?.get_stats().clone())
675    }
676
677    /// Get memory statistics
678    pub fn get_memory_stats() -> Result<(usize, usize)> {
679        HardwareAccelerator::global()?.memory_stats()
680    }
681
682    /// Check if a specific backend is available
683    pub fn is_backend_available(backend: AccelerationBackend) -> bool {
684        HardwareAccelerator::is_backend_available(backend)
685    }
686
687    /// List all available backends
688    pub fn list_available_backends() -> Vec<AccelerationBackend> {
689        let all_backends = [
690            AccelerationBackend::Cuda,
691            AccelerationBackend::Rocm,
692            AccelerationBackend::Intel,
693            AccelerationBackend::Vulkan,
694            AccelerationBackend::Metal,
695            AccelerationBackend::Cpu,
696        ];
697
698        all_backends
699            .into_iter()
700            .filter(|&backend| HardwareAccelerator::is_backend_available(backend))
701            .collect()
702    }
703}
704
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    #[test]
    fn test_hardware_acceleration_initialization() {
        // Initialization with defaults must succeed (CPU fallback at worst).
        assert!(HardwareAccelerator::initialize(AccelerationConfig::default()).is_ok());
    }

    #[test]
    fn test_backend_selection() {
        // CPU is unconditionally available, so the list is never empty.
        let backends = api::list_available_backends();
        assert!(!backends.is_empty());
        assert!(backends.contains(&AccelerationBackend::Cpu));
    }

    #[test]
    fn test_accelerated_matmul() {
        let _ = api::init_hardware_acceleration();

        let lhs = Tensor::ones(&[4, 4]).expect("Failed to create ones tensor");
        let rhs = Tensor::ones(&[4, 4]).expect("Failed to create ones tensor");
        let mut product = Tensor::zeros(&[4, 4]).expect("Failed to create zero tensor");

        assert!(api::accelerated_matmul(&lhs, &rhs, &mut product).is_ok());

        // ones(4x4) @ ones(4x4) yields 4.0 in every cell.
        let values = product.data().expect("operation failed in test");
        assert!(values.iter().all(|&v| (v - 4.0).abs() < 1e-6));
    }

    #[test]
    fn test_device_info() {
        let _ = api::init_hardware_acceleration();
        assert!(api::get_device_info().is_ok());
    }

    #[test]
    fn test_performance_stats() {
        let _ = api::init_hardware_acceleration();
        assert!(api::get_performance_stats().is_ok());
    }

    #[test]
    fn test_memory_stats() {
        let _ = api::init_hardware_acceleration();
        assert!(api::get_memory_stats().is_ok());
    }
}