Skip to main content

sklears_simd/
custom_accelerator.rs

//! Custom accelerator support framework.
//!
//! This module provides a flexible framework for integrating custom hardware
//! accelerators — ASICs, custom chips, and other specialized processing units —
//! with the crate's SIMD operations.
6
7use crate::traits::SimdError;
8
9#[cfg(feature = "no-std")]
10use core::any::Any;
11#[cfg(not(feature = "no-std"))]
12use std::any::Any;
13
14#[cfg(feature = "no-std")]
15use alloc::{
16    boxed::Box,
17    collections::BTreeMap as HashMap,
18    string::{String, ToString},
19    vec,
20    vec::Vec,
21};
22#[cfg(not(feature = "no-std"))]
23use std::{collections::HashMap, string::ToString};
24
25#[cfg(feature = "no-std")]
26extern crate alloc;
27
/// Families of hardware accelerator known to the runtime.
///
/// `PartialOrd`/`Ord` are derived, so comparisons follow declaration order;
/// keep the existing variant order stable.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum AcceleratorType {
    /// Application-specific integrated circuit.
    ASIC,
    /// Custom accelerator.
    Custom,
    /// Digital signal processor.
    DSP,
    /// Vision Processing Unit.
    VPU,
    /// Neural Processing Unit.
    NPU,
    /// Deep Learning Processing Unit.
    DPU,
    /// Generic AI accelerator.
    AI,
    /// Cryptographic accelerator.
    Crypto,
    /// Signal processing accelerator.
    Signal,
    /// Matrix processing accelerator.
    Matrix,
}
42
43/// Accelerator capabilities
44#[derive(Debug, Clone)]
45pub struct AcceleratorCapabilities {
46    pub supported_operations: Vec<AcceleratorOperation>,
47    pub data_types: Vec<AcceleratorDataType>,
48    pub max_batch_size: usize,
49    pub memory_mb: u64,
50    pub compute_units: u32,
51    pub peak_performance_ops: u64,
52    pub power_consumption_w: f64,
53    pub precision_modes: Vec<AcceleratorPrecision>,
54}
55
/// Operations an accelerator may support.
///
/// `Hash` is derived (as on `AcceleratorType`) so operations can be used as
/// `HashSet`/`HashMap` keys, e.g. to index per-operation statistics.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AcceleratorOperation {
    MatrixMultiply,
    Convolution,
    Activation,
    Pooling,
    Normalization,
    Attention,
    Embedding,
    Reduction,
    Transform,
    Sort,
    Search,
    Compress,
    Encrypt,
    Hash,
    /// Backend-defined operation identified by a numeric code.
    Custom(u32),
}
75
/// Element data types an accelerator may accept.
///
/// `Hash` is derived (as on `AcceleratorType`) so data types can be used as
/// set/map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AcceleratorDataType {
    F32,
    F16,
    BF16,
    I32,
    I16,
    I8,
    U32,
    U16,
    U8,
    Bool,
    /// Backend-defined type identified by a numeric code.
    Custom(u32),
}
91
/// Precision modes an accelerator may run in.
///
/// `Hash` is derived (as on `AcceleratorType`) so modes can be used as set/map
/// keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AcceleratorPrecision {
    High,
    Medium,
    Low,
    Mixed,
    Adaptive,
}
101
102/// Accelerator device information
103#[derive(Debug, Clone)]
104pub struct AcceleratorDevice {
105    pub id: u32,
106    pub name: String,
107    pub vendor: String,
108    pub model: String,
109    pub accelerator_type: AcceleratorType,
110    pub capabilities: AcceleratorCapabilities,
111    pub driver_version: String,
112    pub firmware_version: String,
113    pub pci_id: Option<String>,
114    pub numa_node: Option<u32>,
115}
116
117/// Accelerator buffer
118#[derive(Debug)]
119pub struct AcceleratorBuffer<T> {
120    pub ptr: *mut T,
121    pub size: usize,
122    pub device: AcceleratorDevice,
123    pub alignment: usize,
124    pub memory_type: AcceleratorMemoryType,
125    #[allow(dead_code)] // Reserved for native accelerator backend handle (CUDA/OpenCL/custom)
126    backend_handle: Option<Box<dyn Any + Send + Sync>>,
127}
128
/// Kind of memory backing an accelerator allocation or pool.
///
/// This type keys `AcceleratorContext::memory_pools`. That field is a `HashMap`
/// under std and a `BTreeMap` under `no-std`, so `Hash` and `Ord` are derived.
/// Without them no pool could ever be inserted into or looked up in that map.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AcceleratorMemoryType {
    Device,
    Host,
    Unified,
    Pinned,
    Cached,
    Uncached,
}
139
140unsafe impl<T: Send> Send for AcceleratorBuffer<T> {}
141unsafe impl<T: Sync> Sync for AcceleratorBuffer<T> {}
142
143impl<T> Drop for AcceleratorBuffer<T> {
144    fn drop(&mut self) {
145        // Free accelerator memory when buffer is dropped
146    }
147}
148
149/// Accelerator context
150pub struct AcceleratorContext {
151    pub device: AcceleratorDevice,
152    pub command_queues: Vec<AcceleratorQueue>,
153    pub memory_pools: HashMap<AcceleratorMemoryType, AcceleratorMemoryPool>,
154    #[allow(dead_code)] // Reserved for native accelerator backend context (CUDA/OpenCL/custom)
155    backend_context: Option<Box<dyn Any + Send + Sync>>,
156}
157
158/// Accelerator command queue
159#[derive(Debug)]
160pub struct AcceleratorQueue {
161    pub id: u32,
162    pub priority: AcceleratorPriority,
163    pub device_id: u32,
164    #[allow(dead_code)] // Reserved for native command queue handle (CUDA stream / OpenCL queue)
165    backend_queue: Option<Box<dyn Any + Send + Sync>>,
166}
167
/// Scheduling priority of a command queue, from lowest to highest.
///
/// Variants are declared in ascending order, so the derived `Ord` ranks
/// `Low < Normal < High < Critical`. Queues can therefore be sorted or
/// compared by priority directly. Keep that order when adding variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AcceleratorPriority {
    Low,
    Normal,
    High,
    Critical,
}
176
177/// Accelerator memory pool
178#[derive(Debug)]
179pub struct AcceleratorMemoryPool {
180    pub total_size: usize,
181    pub available_size: usize,
182    pub allocation_count: usize,
183    pub memory_type: AcceleratorMemoryType,
184    #[allow(dead_code)] // Reserved for native memory pool handle (CUDA/OpenCL/custom allocator)
185    backend_pool: Option<Box<dyn Any + Send + Sync>>,
186}
187
188/// Accelerator kernel configuration
189#[derive(Debug, Clone)]
190pub struct AcceleratorKernel {
191    pub name: String,
192    pub operation: AcceleratorOperation,
193    pub input_buffers: Vec<u32>,
194    pub output_buffers: Vec<u32>,
195    pub parameters: HashMap<String, AcceleratorParameter>,
196    pub work_size: (usize, usize, usize),
197    pub local_size: (usize, usize, usize),
198}
199
/// Value of a named kernel argument.
#[derive(Clone, Debug)]
pub enum AcceleratorParameter {
    /// Signed integer argument.
    Int(i64),
    /// Floating-point argument.
    Float(f64),
    /// Boolean flag.
    Bool(bool),
    /// Text argument.
    String(String),
    /// Raw byte array.
    Array(Vec<u8>),
}
209
210/// Accelerator operations interface
211pub trait AcceleratorOperations {
212    /// Allocate accelerator memory
213    fn allocate<T>(
214        &self,
215        size: usize,
216        memory_type: AcceleratorMemoryType,
217    ) -> Result<AcceleratorBuffer<T>, SimdError>;
218
219    /// Copy data to accelerator
220    fn copy_to_accelerator<T>(
221        &self,
222        host_data: &[T],
223        accel_buffer: &mut AcceleratorBuffer<T>,
224        queue: Option<&AcceleratorQueue>,
225    ) -> Result<(), SimdError>;
226
227    /// Copy data from accelerator
228    fn copy_from_accelerator<T>(
229        &self,
230        accel_buffer: &AcceleratorBuffer<T>,
231        host_data: &mut [T],
232        queue: Option<&AcceleratorQueue>,
233    ) -> Result<(), SimdError>;
234
235    /// Execute kernel on accelerator
236    fn execute_kernel(
237        &self,
238        kernel: &AcceleratorKernel,
239        buffers: &[&AcceleratorBuffer<u8>],
240        queue: Option<&AcceleratorQueue>,
241    ) -> Result<(), SimdError>;
242
243    /// Synchronize accelerator operations
244    fn synchronize(&self, queue: Option<&AcceleratorQueue>) -> Result<(), SimdError>;
245
246    /// Get accelerator status
247    fn get_status(&self) -> Result<AcceleratorStatus, SimdError>;
248}
249
250/// Accelerator status
251#[derive(Debug, Clone)]
252pub struct AcceleratorStatus {
253    pub utilization_percent: f64,
254    pub memory_usage_percent: f64,
255    pub temperature_c: f64,
256    pub power_consumption_w: f64,
257    pub clock_frequency_mhz: f64,
258    pub active_operations: Vec<AcceleratorOperation>,
259    pub error_count: u32,
260}
261
262/// Accelerator runtime
263pub struct AcceleratorRuntime {
264    devices: Vec<AcceleratorDevice>,
265    contexts: Vec<AcceleratorContext>,
266    drivers: HashMap<AcceleratorType, Box<dyn AcceleratorDriver>>,
267}
268
269/// Accelerator driver interface
270pub trait AcceleratorDriver: Send + Sync {
271    fn initialize(&self) -> Result<(), SimdError>;
272    fn discover_devices(&self) -> Result<Vec<AcceleratorDevice>, SimdError>;
273    fn create_context(&self, device: &AcceleratorDevice) -> Result<AcceleratorContext, SimdError>;
274    fn is_available(&self) -> bool;
275}
276
277impl AcceleratorRuntime {
278    /// Create new accelerator runtime
279    pub fn new() -> Result<Self, SimdError> {
280        let mut runtime = Self {
281            devices: Vec::new(),
282            contexts: Vec::new(),
283            drivers: HashMap::new(),
284        };
285
286        // Register built-in drivers
287        runtime.register_driver(AcceleratorType::ASIC, Box::new(AsicDriver::new()));
288        runtime.register_driver(AcceleratorType::DSP, Box::new(DspDriver::new()));
289        runtime.register_driver(AcceleratorType::VPU, Box::new(VpuDriver::new()));
290        runtime.register_driver(AcceleratorType::NPU, Box::new(NpuDriver::new()));
291
292        // Discover devices
293        runtime.discover_all_devices()?;
294
295        Ok(runtime)
296    }
297
298    /// Register accelerator driver
299    pub fn register_driver(
300        &mut self,
301        accel_type: AcceleratorType,
302        driver: Box<dyn AcceleratorDriver>,
303    ) {
304        self.drivers.insert(accel_type, driver);
305    }
306
307    /// Discover all devices
308    fn discover_all_devices(&mut self) -> Result<(), SimdError> {
309        for driver in self.drivers.values() {
310            if driver.is_available() {
311                let devices = driver.discover_devices()?;
312                self.devices.extend(devices);
313            }
314        }
315        Ok(())
316    }
317
318    /// Get available devices
319    pub fn devices(&self) -> &[AcceleratorDevice] {
320        &self.devices
321    }
322
323    /// Get devices by type
324    pub fn devices_by_type(&self, accel_type: AcceleratorType) -> Vec<&AcceleratorDevice> {
325        self.devices
326            .iter()
327            .filter(|d| d.accelerator_type == accel_type)
328            .collect()
329    }
330
331    /// Create context for device
332    pub fn create_context(&mut self, device_id: u32) -> Result<&AcceleratorContext, SimdError> {
333        let device = self.devices.get(device_id as usize).ok_or_else(|| {
334            SimdError::InvalidArgument("Invalid accelerator device ID".to_string())
335        })?;
336
337        let driver = self
338            .drivers
339            .get(&device.accelerator_type)
340            .ok_or_else(|| SimdError::NotImplemented("Driver not available".to_string()))?;
341
342        let context = driver.create_context(device)?;
343        self.contexts.push(context);
344        Ok(self
345            .contexts
346            .last()
347            .expect("collection should not be empty"))
348    }
349
350    /// Get best device for operation
351    pub fn get_best_device(&self, operation: AcceleratorOperation) -> Option<&AcceleratorDevice> {
352        self.devices
353            .iter()
354            .filter(|d| d.capabilities.supported_operations.contains(&operation))
355            .max_by(|a, b| {
356                a.capabilities
357                    .peak_performance_ops
358                    .cmp(&b.capabilities.peak_performance_ops)
359            })
360    }
361}
362
363/// Built-in accelerator drivers
364struct AsicDriver;
365struct DspDriver;
366struct VpuDriver;
367struct NpuDriver;
368
369impl AsicDriver {
370    fn new() -> Self {
371        Self
372    }
373}
374
375impl AcceleratorDriver for AsicDriver {
376    fn initialize(&self) -> Result<(), SimdError> {
377        Ok(())
378    }
379
380    fn discover_devices(&self) -> Result<Vec<AcceleratorDevice>, SimdError> {
381        Ok(vec![])
382    }
383
384    fn create_context(&self, device: &AcceleratorDevice) -> Result<AcceleratorContext, SimdError> {
385        Ok(AcceleratorContext {
386            device: device.clone(),
387            command_queues: vec![],
388            memory_pools: HashMap::new(),
389            backend_context: None,
390        })
391    }
392
393    fn is_available(&self) -> bool {
394        false
395    }
396}
397
398impl DspDriver {
399    fn new() -> Self {
400        Self
401    }
402}
403
404impl AcceleratorDriver for DspDriver {
405    fn initialize(&self) -> Result<(), SimdError> {
406        Ok(())
407    }
408
409    fn discover_devices(&self) -> Result<Vec<AcceleratorDevice>, SimdError> {
410        Ok(vec![])
411    }
412
413    fn create_context(&self, device: &AcceleratorDevice) -> Result<AcceleratorContext, SimdError> {
414        Ok(AcceleratorContext {
415            device: device.clone(),
416            command_queues: vec![],
417            memory_pools: HashMap::new(),
418            backend_context: None,
419        })
420    }
421
422    fn is_available(&self) -> bool {
423        false
424    }
425}
426
427impl VpuDriver {
428    fn new() -> Self {
429        Self
430    }
431}
432
433impl AcceleratorDriver for VpuDriver {
434    fn initialize(&self) -> Result<(), SimdError> {
435        Ok(())
436    }
437
438    fn discover_devices(&self) -> Result<Vec<AcceleratorDevice>, SimdError> {
439        Ok(vec![])
440    }
441
442    fn create_context(&self, device: &AcceleratorDevice) -> Result<AcceleratorContext, SimdError> {
443        Ok(AcceleratorContext {
444            device: device.clone(),
445            command_queues: vec![],
446            memory_pools: HashMap::new(),
447            backend_context: None,
448        })
449    }
450
451    fn is_available(&self) -> bool {
452        false
453    }
454}
455
456impl NpuDriver {
457    fn new() -> Self {
458        Self
459    }
460}
461
462impl AcceleratorDriver for NpuDriver {
463    fn initialize(&self) -> Result<(), SimdError> {
464        Ok(())
465    }
466
467    fn discover_devices(&self) -> Result<Vec<AcceleratorDevice>, SimdError> {
468        Ok(vec![])
469    }
470
471    fn create_context(&self, device: &AcceleratorDevice) -> Result<AcceleratorContext, SimdError> {
472        Ok(AcceleratorContext {
473            device: device.clone(),
474            command_queues: vec![],
475            memory_pools: HashMap::new(),
476            backend_context: None,
477        })
478    }
479
480    fn is_available(&self) -> bool {
481        false
482    }
483}
484
485/// Accelerator optimization utilities
486pub mod optimization {
487    use super::*;
488
489    /// Optimize kernel for specific accelerator
490    pub fn optimize_kernel(
491        kernel: &AcceleratorKernel,
492        device: &AcceleratorDevice,
493    ) -> Result<AcceleratorKernel, SimdError> {
494        let mut optimized = kernel.clone();
495
496        // Optimize based on device capabilities
497        match device.accelerator_type {
498            AcceleratorType::NPU => {
499                // Optimize for neural processing
500                optimized.work_size = optimize_for_npu(kernel.work_size, &device.capabilities);
501            }
502            AcceleratorType::VPU => {
503                // Optimize for vision processing
504                optimized.work_size = optimize_for_vpu(kernel.work_size, &device.capabilities);
505            }
506            AcceleratorType::DSP => {
507                // Optimize for signal processing
508                optimized.work_size = optimize_for_dsp(kernel.work_size, &device.capabilities);
509            }
510            _ => {
511                // Generic optimization
512                optimized.work_size = optimize_generic(kernel.work_size, &device.capabilities);
513            }
514        }
515
516        Ok(optimized)
517    }
518
519    fn optimize_for_npu(
520        work_size: (usize, usize, usize),
521        caps: &AcceleratorCapabilities,
522    ) -> (usize, usize, usize) {
523        let optimal_batch = caps.max_batch_size.min(work_size.0);
524        (optimal_batch, work_size.1, work_size.2)
525    }
526
527    fn optimize_for_vpu(
528        work_size: (usize, usize, usize),
529        caps: &AcceleratorCapabilities,
530    ) -> (usize, usize, usize) {
531        let compute_units = caps.compute_units as usize;
532        let optimal_x = work_size.0.div_ceil(compute_units) * compute_units;
533        (optimal_x, work_size.1, work_size.2)
534    }
535
536    fn optimize_for_dsp(
537        work_size: (usize, usize, usize),
538        _caps: &AcceleratorCapabilities,
539    ) -> (usize, usize, usize) {
540        // DSP typically works well with power-of-2 sizes
541        let next_power_of_2 = |n: usize| {
542            if n == 0 {
543                1
544            } else {
545                1 << (64 - (n - 1).leading_zeros())
546            }
547        };
548
549        (next_power_of_2(work_size.0), work_size.1, work_size.2)
550    }
551
552    fn optimize_generic(
553        work_size: (usize, usize, usize),
554        caps: &AcceleratorCapabilities,
555    ) -> (usize, usize, usize) {
556        let compute_units = caps.compute_units as usize;
557        let optimal_size = work_size.0.div_ceil(compute_units) * compute_units;
558        (optimal_size, work_size.1, work_size.2)
559    }
560
561    /// Select best accelerator for operation
562    pub fn select_accelerator(
563        operation: AcceleratorOperation,
564        data_size: usize,
565        devices: &[AcceleratorDevice],
566    ) -> Option<&AcceleratorDevice> {
567        devices
568            .iter()
569            .filter(|d| d.capabilities.supported_operations.contains(&operation))
570            .filter(|d| d.capabilities.max_batch_size >= data_size)
571            .max_by(|a, b| {
572                let score_a = compute_device_score(a, operation, data_size);
573                let score_b = compute_device_score(b, operation, data_size);
574                score_a
575                    .partial_cmp(&score_b)
576                    .unwrap_or(core::cmp::Ordering::Equal)
577            })
578    }
579
580    fn compute_device_score(
581        device: &AcceleratorDevice,
582        operation: AcceleratorOperation,
583        data_size: usize,
584    ) -> f64 {
585        let mut score = 0.0;
586
587        // Performance score
588        score += device.capabilities.peak_performance_ops as f64 / 1e9;
589
590        // Memory score
591        let memory_ratio =
592            data_size as f64 / (device.capabilities.memory_mb as f64 * 1024.0 * 1024.0);
593        score += if memory_ratio <= 1.0 {
594            1.0
595        } else {
596            1.0 / memory_ratio
597        };
598
599        // Power efficiency score
600        let ops_per_watt = device.capabilities.peak_performance_ops as f64
601            / device.capabilities.power_consumption_w;
602        score += ops_per_watt / 1e9;
603
604        // Operation-specific bonuses
605        match (operation, device.accelerator_type) {
606            (AcceleratorOperation::MatrixMultiply, AcceleratorType::NPU) => score += 2.0,
607            (AcceleratorOperation::Convolution, AcceleratorType::VPU) => score += 2.0,
608            (AcceleratorOperation::Transform, AcceleratorType::DSP) => score += 2.0,
609            _ => {}
610        }
611
612        score
613    }
614}
615
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Baseline capabilities for hand-built test devices; tests override the
    /// fields they care about with struct-update syntax.
    fn base_caps(supported_operations: Vec<AcceleratorOperation>) -> AcceleratorCapabilities {
        AcceleratorCapabilities {
            supported_operations,
            data_types: vec![AcceleratorDataType::F32],
            max_batch_size: 1024,
            memory_mb: 4096,
            compute_units: 32,
            peak_performance_ops: 1_000_000_000,
            power_consumption_w: 100.0,
            precision_modes: vec![AcceleratorPrecision::High],
        }
    }

    /// Test device with placeholder vendor and version metadata.
    fn test_device(
        id: u32,
        name: &str,
        model: &str,
        accelerator_type: AcceleratorType,
        capabilities: AcceleratorCapabilities,
    ) -> AcceleratorDevice {
        AcceleratorDevice {
            id,
            name: name.to_string(),
            vendor: "Test".to_string(),
            model: model.to_string(),
            accelerator_type,
            capabilities,
            driver_version: "1.0.0".to_string(),
            firmware_version: "1.0.0".to_string(),
            pci_id: None,
            numa_node: None,
        }
    }

    /// Matrix-multiply kernel with the given global work size.
    fn test_kernel(work_size: (usize, usize, usize)) -> AcceleratorKernel {
        AcceleratorKernel {
            name: "test".to_string(),
            operation: AcceleratorOperation::MatrixMultiply,
            input_buffers: vec![0, 1],
            output_buffers: vec![2],
            parameters: HashMap::new(),
            work_size,
            local_size: (16, 16, 1),
        }
    }

    #[test]
    fn test_accelerator_runtime_creation() {
        let runtime = AcceleratorRuntime::new();
        assert!(runtime.is_ok());
    }

    #[test]
    fn test_runtime_with_only_builtin_drivers_has_no_devices() {
        // Every built-in driver reports itself unavailable, so nothing is discovered.
        let Ok(mut runtime) = AcceleratorRuntime::new() else {
            panic!("runtime creation should succeed");
        };
        assert!(runtime.devices().is_empty());
        assert!(runtime.devices_by_type(AcceleratorType::NPU).is_empty());
        assert!(runtime
            .get_best_device(AcceleratorOperation::MatrixMultiply)
            .is_none());
        assert!(runtime.create_context(0).is_err());
    }

    #[test]
    fn test_accelerator_type_display() {
        let cases = [
            (AcceleratorType::ASIC, "ASIC"),
            (AcceleratorType::Custom, "Custom"),
            (AcceleratorType::DSP, "DSP"),
            (AcceleratorType::VPU, "VPU"),
            (AcceleratorType::NPU, "NPU"),
        ];

        for (accel_type, expected) in cases {
            assert_eq!(format!("{:?}", accel_type), expected);
        }
    }

    #[test]
    fn test_accelerator_capabilities() {
        let caps = AcceleratorCapabilities {
            supported_operations: vec![
                AcceleratorOperation::MatrixMultiply,
                AcceleratorOperation::Convolution,
            ],
            data_types: vec![AcceleratorDataType::F32, AcceleratorDataType::F16],
            max_batch_size: 1024,
            memory_mb: 8192,
            compute_units: 64,
            peak_performance_ops: 1_000_000_000,
            power_consumption_w: 150.0,
            precision_modes: vec![AcceleratorPrecision::High, AcceleratorPrecision::Mixed],
        };

        assert_eq!(caps.supported_operations.len(), 2);
        assert_eq!(caps.data_types.len(), 2);
        assert_eq!(caps.max_batch_size, 1024);
    }

    #[test]
    fn test_accelerator_kernel() {
        let mut params = HashMap::new();
        params.insert("alpha".to_string(), AcceleratorParameter::Float(1.0));
        params.insert("beta".to_string(), AcceleratorParameter::Float(0.0));

        let kernel = AcceleratorKernel {
            name: "test_kernel".to_string(),
            operation: AcceleratorOperation::MatrixMultiply,
            input_buffers: vec![0, 1],
            output_buffers: vec![2],
            parameters: params,
            work_size: (1024, 1024, 1),
            local_size: (16, 16, 1),
        };

        assert_eq!(kernel.name, "test_kernel");
        assert_eq!(kernel.operation, AcceleratorOperation::MatrixMultiply);
        assert_eq!(kernel.input_buffers.len(), 2);
        assert_eq!(kernel.output_buffers.len(), 1);
        assert_eq!(kernel.parameters.len(), 2);
    }

    #[test]
    fn test_accelerator_device() {
        let device = AcceleratorDevice {
            id: 0,
            name: "Test Accelerator".to_string(),
            vendor: "Test Vendor".to_string(),
            model: "Test Model".to_string(),
            accelerator_type: AcceleratorType::NPU,
            capabilities: AcceleratorCapabilities {
                supported_operations: vec![AcceleratorOperation::MatrixMultiply],
                data_types: vec![AcceleratorDataType::F32],
                max_batch_size: 512,
                memory_mb: 4096,
                compute_units: 32,
                peak_performance_ops: 500_000_000,
                power_consumption_w: 100.0,
                precision_modes: vec![AcceleratorPrecision::High],
            },
            driver_version: "1.0.0".to_string(),
            firmware_version: "2.0.0".to_string(),
            pci_id: Some("1234:5678".to_string()),
            numa_node: Some(0),
        };

        assert_eq!(device.id, 0);
        assert_eq!(device.accelerator_type, AcceleratorType::NPU);
        assert_eq!(device.capabilities.compute_units, 32);
    }

    #[test]
    fn test_kernel_optimization() {
        let device = test_device(
            0,
            "Test NPU",
            "NPU-100",
            AcceleratorType::NPU,
            AcceleratorCapabilities {
                max_batch_size: 256,
                memory_mb: 2048,
                compute_units: 16,
                peak_performance_ops: 100_000_000,
                power_consumption_w: 50.0,
                ..base_caps(vec![AcceleratorOperation::MatrixMultiply])
            },
        );

        let optimized = optimization::optimize_kernel(&test_kernel((512, 512, 1)), &device);
        assert!(optimized.is_ok());

        let opt_kernel = optimized.expect("operation should succeed");
        assert_eq!(opt_kernel.work_size.0, 256); // Limited by max_batch_size
    }

    #[test]
    fn test_work_size_rounding_by_device_type() {
        let kernel = test_kernel((300, 7, 3));
        let dsp = test_device(
            0,
            "DSP",
            "DSP-1",
            AcceleratorType::DSP,
            base_caps(vec![AcceleratorOperation::Transform]),
        );
        let vpu = test_device(
            1,
            "VPU",
            "VPU-1",
            AcceleratorType::VPU,
            AcceleratorCapabilities {
                compute_units: 16,
                ..base_caps(vec![AcceleratorOperation::Convolution])
            },
        );
        let asic = test_device(
            2,
            "ASIC",
            "ASIC-1",
            AcceleratorType::ASIC,
            base_caps(vec![AcceleratorOperation::Hash]),
        );

        // DSP: next power of two.
        let dsp_kernel =
            optimization::optimize_kernel(&kernel, &dsp).expect("DSP optimization should succeed");
        assert_eq!(dsp_kernel.work_size, (512, 7, 3));
        assert_eq!(dsp_kernel.local_size, kernel.local_size);

        // VPU: next multiple of 16 compute units.
        let vpu_kernel =
            optimization::optimize_kernel(&kernel, &vpu).expect("VPU optimization should succeed");
        assert_eq!(vpu_kernel.work_size, (304, 7, 3));

        // Generic: next multiple of 32 compute units.
        let asic_kernel = optimization::optimize_kernel(&kernel, &asic)
            .expect("generic optimization should succeed");
        assert_eq!(asic_kernel.work_size, (320, 7, 3));
    }

    #[test]
    fn test_accelerator_selection() {
        let devices = vec![
            test_device(
                0,
                "NPU",
                "NPU-100",
                AcceleratorType::NPU,
                base_caps(vec![AcceleratorOperation::MatrixMultiply]),
            ),
            test_device(
                1,
                "VPU",
                "VPU-200",
                AcceleratorType::VPU,
                AcceleratorCapabilities {
                    data_types: vec![AcceleratorDataType::F16],
                    max_batch_size: 512,
                    memory_mb: 2048,
                    compute_units: 16,
                    peak_performance_ops: 500_000_000,
                    power_consumption_w: 50.0,
                    precision_modes: vec![AcceleratorPrecision::Mixed],
                    ..base_caps(vec![AcceleratorOperation::Convolution])
                },
            ),
        ];

        let selected =
            optimization::select_accelerator(AcceleratorOperation::MatrixMultiply, 512, &devices);
        assert!(selected.is_some());
        assert_eq!(selected.expect("operation should succeed").id, 0);

        let selected =
            optimization::select_accelerator(AcceleratorOperation::Convolution, 256, &devices);
        assert!(selected.is_some());
        assert_eq!(selected.expect("operation should succeed").id, 1);
    }
}