1use crate::traits::SimdError;
8
9#[cfg(feature = "no-std")]
10use core::any::Any;
11#[cfg(not(feature = "no-std"))]
12use std::any::Any;
13
14#[cfg(feature = "no-std")]
15use alloc::{
16 boxed::Box,
17 collections::BTreeMap as HashMap,
18 string::{String, ToString},
19 vec,
20 vec::Vec,
21};
22#[cfg(not(feature = "no-std"))]
23use std::{collections::HashMap, string::ToString};
24
25#[cfg(feature = "no-std")]
26extern crate alloc;
27
/// Category of accelerator hardware a device or driver belongs to.
///
/// The runtime keys its driver registry by this value, and the kernel
/// optimizer picks its work-size strategy from it. The derived ordering
/// follows declaration order, so new variants should be appended rather than
/// inserted.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum AcceleratorType {
    ASIC,
    Custom,
    DSP,
    VPU,
    NPU,
    DPU,
    AI,
    Crypto,
    Signal,
    Matrix,
}
42
43#[derive(Debug, Clone)]
45pub struct AcceleratorCapabilities {
46 pub supported_operations: Vec<AcceleratorOperation>,
47 pub data_types: Vec<AcceleratorDataType>,
48 pub max_batch_size: usize,
49 pub memory_mb: u64,
50 pub compute_units: u32,
51 pub peak_performance_ops: u64,
52 pub power_consumption_w: f64,
53 pub precision_modes: Vec<AcceleratorPrecision>,
54}
55
/// Kind of work a kernel performs or a device advertises support for.
///
/// Compared by value: two `Custom` operations are equal only when their
/// numeric codes match.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AcceleratorOperation {
    MatrixMultiply,
    Convolution,
    Activation,
    Pooling,
    Normalization,
    Attention,
    Embedding,
    Reduction,
    Transform,
    Sort,
    Search,
    Compress,
    Encrypt,
    Hash,
    /// Operation identified only by a caller-defined numeric code.
    Custom(u32),
}
75
/// Element type tag used when describing what a device supports.
///
/// Compared by value: two `Custom` types are equal only when their numeric
/// codes match.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AcceleratorDataType {
    F32,
    F16,
    BF16,
    I32,
    I16,
    I8,
    U32,
    U16,
    U8,
    Bool,
    /// Type identified only by a caller-defined numeric code.
    Custom(u32),
}
91
/// Precision mode a device can advertise in its capabilities.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AcceleratorPrecision {
    High,
    Medium,
    Low,
    Mixed,
    Adaptive,
}
101
102#[derive(Debug, Clone)]
104pub struct AcceleratorDevice {
105 pub id: u32,
106 pub name: String,
107 pub vendor: String,
108 pub model: String,
109 pub accelerator_type: AcceleratorType,
110 pub capabilities: AcceleratorCapabilities,
111 pub driver_version: String,
112 pub firmware_version: String,
113 pub pci_id: Option<String>,
114 pub numa_node: Option<u32>,
115}
116
117#[derive(Debug)]
119pub struct AcceleratorBuffer<T> {
120 pub ptr: *mut T,
121 pub size: usize,
122 pub device: AcceleratorDevice,
123 pub alignment: usize,
124 pub memory_type: AcceleratorMemoryType,
125 #[allow(dead_code)] backend_handle: Option<Box<dyn Any + Send + Sync>>,
127}
128
/// Kind of memory a buffer or memory pool lives in.
///
/// `AcceleratorContext::memory_pools` is keyed by this type. That map is a
/// `HashMap` in std builds and a `BTreeMap` in `no-std` builds, so the type
/// derives both `Hash` and `Ord`; without them no pool could ever be inserted
/// or looked up. The derived ordering follows declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AcceleratorMemoryType {
    Device,
    Host,
    Unified,
    Pinned,
    Cached,
    Uncached,
}
139
140unsafe impl<T: Send> Send for AcceleratorBuffer<T> {}
141unsafe impl<T: Sync> Sync for AcceleratorBuffer<T> {}
142
143impl<T> Drop for AcceleratorBuffer<T> {
144 fn drop(&mut self) {
145 }
147}
148
149pub struct AcceleratorContext {
151 pub device: AcceleratorDevice,
152 pub command_queues: Vec<AcceleratorQueue>,
153 pub memory_pools: HashMap<AcceleratorMemoryType, AcceleratorMemoryPool>,
154 #[allow(dead_code)] backend_context: Option<Box<dyn Any + Send + Sync>>,
156}
157
158#[derive(Debug)]
160pub struct AcceleratorQueue {
161 pub id: u32,
162 pub priority: AcceleratorPriority,
163 pub device_id: u32,
164 #[allow(dead_code)] backend_queue: Option<Box<dyn Any + Send + Sync>>,
166}
167
/// Priority level assigned to a command queue.
///
/// Variants are declared from least to most urgent, and the derived ordering
/// follows that order (`Low < Normal < High < Critical`), so priorities can
/// be compared, sorted, and used as map keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AcceleratorPriority {
    Low,
    Normal,
    High,
    Critical,
}
176
177#[derive(Debug)]
179pub struct AcceleratorMemoryPool {
180 pub total_size: usize,
181 pub available_size: usize,
182 pub allocation_count: usize,
183 pub memory_type: AcceleratorMemoryType,
184 #[allow(dead_code)] backend_pool: Option<Box<dyn Any + Send + Sync>>,
186}
187
188#[derive(Debug, Clone)]
190pub struct AcceleratorKernel {
191 pub name: String,
192 pub operation: AcceleratorOperation,
193 pub input_buffers: Vec<u32>,
194 pub output_buffers: Vec<u32>,
195 pub parameters: HashMap<String, AcceleratorParameter>,
196 pub work_size: (usize, usize, usize),
197 pub local_size: (usize, usize, usize),
198}
199
/// Value of a named kernel parameter.
#[derive(Clone, Debug)]
pub enum AcceleratorParameter {
    /// Signed integer value.
    Int(i64),
    /// Floating-point value.
    Float(f64),
    /// Boolean flag.
    Bool(bool),
    /// Owned text value.
    String(String),
    /// Raw bytes.
    Array(Vec<u8>),
}
209
210pub trait AcceleratorOperations {
212 fn allocate<T>(
214 &self,
215 size: usize,
216 memory_type: AcceleratorMemoryType,
217 ) -> Result<AcceleratorBuffer<T>, SimdError>;
218
219 fn copy_to_accelerator<T>(
221 &self,
222 host_data: &[T],
223 accel_buffer: &mut AcceleratorBuffer<T>,
224 queue: Option<&AcceleratorQueue>,
225 ) -> Result<(), SimdError>;
226
227 fn copy_from_accelerator<T>(
229 &self,
230 accel_buffer: &AcceleratorBuffer<T>,
231 host_data: &mut [T],
232 queue: Option<&AcceleratorQueue>,
233 ) -> Result<(), SimdError>;
234
235 fn execute_kernel(
237 &self,
238 kernel: &AcceleratorKernel,
239 buffers: &[&AcceleratorBuffer<u8>],
240 queue: Option<&AcceleratorQueue>,
241 ) -> Result<(), SimdError>;
242
243 fn synchronize(&self, queue: Option<&AcceleratorQueue>) -> Result<(), SimdError>;
245
246 fn get_status(&self) -> Result<AcceleratorStatus, SimdError>;
248}
249
250#[derive(Debug, Clone)]
252pub struct AcceleratorStatus {
253 pub utilization_percent: f64,
254 pub memory_usage_percent: f64,
255 pub temperature_c: f64,
256 pub power_consumption_w: f64,
257 pub clock_frequency_mhz: f64,
258 pub active_operations: Vec<AcceleratorOperation>,
259 pub error_count: u32,
260}
261
262pub struct AcceleratorRuntime {
264 devices: Vec<AcceleratorDevice>,
265 contexts: Vec<AcceleratorContext>,
266 drivers: HashMap<AcceleratorType, Box<dyn AcceleratorDriver>>,
267}
268
269pub trait AcceleratorDriver: Send + Sync {
271 fn initialize(&self) -> Result<(), SimdError>;
272 fn discover_devices(&self) -> Result<Vec<AcceleratorDevice>, SimdError>;
273 fn create_context(&self, device: &AcceleratorDevice) -> Result<AcceleratorContext, SimdError>;
274 fn is_available(&self) -> bool;
275}
276
277impl AcceleratorRuntime {
278 pub fn new() -> Result<Self, SimdError> {
280 let mut runtime = Self {
281 devices: Vec::new(),
282 contexts: Vec::new(),
283 drivers: HashMap::new(),
284 };
285
286 runtime.register_driver(AcceleratorType::ASIC, Box::new(AsicDriver::new()));
288 runtime.register_driver(AcceleratorType::DSP, Box::new(DspDriver::new()));
289 runtime.register_driver(AcceleratorType::VPU, Box::new(VpuDriver::new()));
290 runtime.register_driver(AcceleratorType::NPU, Box::new(NpuDriver::new()));
291
292 runtime.discover_all_devices()?;
294
295 Ok(runtime)
296 }
297
298 pub fn register_driver(
300 &mut self,
301 accel_type: AcceleratorType,
302 driver: Box<dyn AcceleratorDriver>,
303 ) {
304 self.drivers.insert(accel_type, driver);
305 }
306
307 fn discover_all_devices(&mut self) -> Result<(), SimdError> {
309 for driver in self.drivers.values() {
310 if driver.is_available() {
311 let devices = driver.discover_devices()?;
312 self.devices.extend(devices);
313 }
314 }
315 Ok(())
316 }
317
318 pub fn devices(&self) -> &[AcceleratorDevice] {
320 &self.devices
321 }
322
323 pub fn devices_by_type(&self, accel_type: AcceleratorType) -> Vec<&AcceleratorDevice> {
325 self.devices
326 .iter()
327 .filter(|d| d.accelerator_type == accel_type)
328 .collect()
329 }
330
331 pub fn create_context(&mut self, device_id: u32) -> Result<&AcceleratorContext, SimdError> {
333 let device = self.devices.get(device_id as usize).ok_or_else(|| {
334 SimdError::InvalidArgument("Invalid accelerator device ID".to_string())
335 })?;
336
337 let driver = self
338 .drivers
339 .get(&device.accelerator_type)
340 .ok_or_else(|| SimdError::NotImplemented("Driver not available".to_string()))?;
341
342 let context = driver.create_context(device)?;
343 self.contexts.push(context);
344 Ok(self
345 .contexts
346 .last()
347 .expect("collection should not be empty"))
348 }
349
350 pub fn get_best_device(&self, operation: AcceleratorOperation) -> Option<&AcceleratorDevice> {
352 self.devices
353 .iter()
354 .filter(|d| d.capabilities.supported_operations.contains(&operation))
355 .max_by(|a, b| {
356 a.capabilities
357 .peak_performance_ops
358 .cmp(&b.capabilities.peak_performance_ops)
359 })
360 }
361}
362
363struct AsicDriver;
365struct DspDriver;
366struct VpuDriver;
367struct NpuDriver;
368
369impl AsicDriver {
370 fn new() -> Self {
371 Self
372 }
373}
374
375impl AcceleratorDriver for AsicDriver {
376 fn initialize(&self) -> Result<(), SimdError> {
377 Ok(())
378 }
379
380 fn discover_devices(&self) -> Result<Vec<AcceleratorDevice>, SimdError> {
381 Ok(vec![])
382 }
383
384 fn create_context(&self, device: &AcceleratorDevice) -> Result<AcceleratorContext, SimdError> {
385 Ok(AcceleratorContext {
386 device: device.clone(),
387 command_queues: vec![],
388 memory_pools: HashMap::new(),
389 backend_context: None,
390 })
391 }
392
393 fn is_available(&self) -> bool {
394 false
395 }
396}
397
398impl DspDriver {
399 fn new() -> Self {
400 Self
401 }
402}
403
404impl AcceleratorDriver for DspDriver {
405 fn initialize(&self) -> Result<(), SimdError> {
406 Ok(())
407 }
408
409 fn discover_devices(&self) -> Result<Vec<AcceleratorDevice>, SimdError> {
410 Ok(vec![])
411 }
412
413 fn create_context(&self, device: &AcceleratorDevice) -> Result<AcceleratorContext, SimdError> {
414 Ok(AcceleratorContext {
415 device: device.clone(),
416 command_queues: vec![],
417 memory_pools: HashMap::new(),
418 backend_context: None,
419 })
420 }
421
422 fn is_available(&self) -> bool {
423 false
424 }
425}
426
427impl VpuDriver {
428 fn new() -> Self {
429 Self
430 }
431}
432
433impl AcceleratorDriver for VpuDriver {
434 fn initialize(&self) -> Result<(), SimdError> {
435 Ok(())
436 }
437
438 fn discover_devices(&self) -> Result<Vec<AcceleratorDevice>, SimdError> {
439 Ok(vec![])
440 }
441
442 fn create_context(&self, device: &AcceleratorDevice) -> Result<AcceleratorContext, SimdError> {
443 Ok(AcceleratorContext {
444 device: device.clone(),
445 command_queues: vec![],
446 memory_pools: HashMap::new(),
447 backend_context: None,
448 })
449 }
450
451 fn is_available(&self) -> bool {
452 false
453 }
454}
455
456impl NpuDriver {
457 fn new() -> Self {
458 Self
459 }
460}
461
462impl AcceleratorDriver for NpuDriver {
463 fn initialize(&self) -> Result<(), SimdError> {
464 Ok(())
465 }
466
467 fn discover_devices(&self) -> Result<Vec<AcceleratorDevice>, SimdError> {
468 Ok(vec![])
469 }
470
471 fn create_context(&self, device: &AcceleratorDevice) -> Result<AcceleratorContext, SimdError> {
472 Ok(AcceleratorContext {
473 device: device.clone(),
474 command_queues: vec![],
475 memory_pools: HashMap::new(),
476 backend_context: None,
477 })
478 }
479
480 fn is_available(&self) -> bool {
481 false
482 }
483}
484
485pub mod optimization {
487 use super::*;
488
489 pub fn optimize_kernel(
491 kernel: &AcceleratorKernel,
492 device: &AcceleratorDevice,
493 ) -> Result<AcceleratorKernel, SimdError> {
494 let mut optimized = kernel.clone();
495
496 match device.accelerator_type {
498 AcceleratorType::NPU => {
499 optimized.work_size = optimize_for_npu(kernel.work_size, &device.capabilities);
501 }
502 AcceleratorType::VPU => {
503 optimized.work_size = optimize_for_vpu(kernel.work_size, &device.capabilities);
505 }
506 AcceleratorType::DSP => {
507 optimized.work_size = optimize_for_dsp(kernel.work_size, &device.capabilities);
509 }
510 _ => {
511 optimized.work_size = optimize_generic(kernel.work_size, &device.capabilities);
513 }
514 }
515
516 Ok(optimized)
517 }
518
519 fn optimize_for_npu(
520 work_size: (usize, usize, usize),
521 caps: &AcceleratorCapabilities,
522 ) -> (usize, usize, usize) {
523 let optimal_batch = caps.max_batch_size.min(work_size.0);
524 (optimal_batch, work_size.1, work_size.2)
525 }
526
527 fn optimize_for_vpu(
528 work_size: (usize, usize, usize),
529 caps: &AcceleratorCapabilities,
530 ) -> (usize, usize, usize) {
531 let compute_units = caps.compute_units as usize;
532 let optimal_x = work_size.0.div_ceil(compute_units) * compute_units;
533 (optimal_x, work_size.1, work_size.2)
534 }
535
536 fn optimize_for_dsp(
537 work_size: (usize, usize, usize),
538 _caps: &AcceleratorCapabilities,
539 ) -> (usize, usize, usize) {
540 let next_power_of_2 = |n: usize| {
542 if n == 0 {
543 1
544 } else {
545 1 << (64 - (n - 1).leading_zeros())
546 }
547 };
548
549 (next_power_of_2(work_size.0), work_size.1, work_size.2)
550 }
551
552 fn optimize_generic(
553 work_size: (usize, usize, usize),
554 caps: &AcceleratorCapabilities,
555 ) -> (usize, usize, usize) {
556 let compute_units = caps.compute_units as usize;
557 let optimal_size = work_size.0.div_ceil(compute_units) * compute_units;
558 (optimal_size, work_size.1, work_size.2)
559 }
560
561 pub fn select_accelerator(
563 operation: AcceleratorOperation,
564 data_size: usize,
565 devices: &[AcceleratorDevice],
566 ) -> Option<&AcceleratorDevice> {
567 devices
568 .iter()
569 .filter(|d| d.capabilities.supported_operations.contains(&operation))
570 .filter(|d| d.capabilities.max_batch_size >= data_size)
571 .max_by(|a, b| {
572 let score_a = compute_device_score(a, operation, data_size);
573 let score_b = compute_device_score(b, operation, data_size);
574 score_a
575 .partial_cmp(&score_b)
576 .unwrap_or(core::cmp::Ordering::Equal)
577 })
578 }
579
580 fn compute_device_score(
581 device: &AcceleratorDevice,
582 operation: AcceleratorOperation,
583 data_size: usize,
584 ) -> f64 {
585 let mut score = 0.0;
586
587 score += device.capabilities.peak_performance_ops as f64 / 1e9;
589
590 let memory_ratio =
592 data_size as f64 / (device.capabilities.memory_mb as f64 * 1024.0 * 1024.0);
593 score += if memory_ratio <= 1.0 {
594 1.0
595 } else {
596 1.0 / memory_ratio
597 };
598
599 let ops_per_watt = device.capabilities.peak_performance_ops as f64
601 / device.capabilities.power_consumption_w;
602 score += ops_per_watt / 1e9;
603
604 match (operation, device.accelerator_type) {
606 (AcceleratorOperation::MatrixMultiply, AcceleratorType::NPU) => score += 2.0,
607 (AcceleratorOperation::Convolution, AcceleratorType::VPU) => score += 2.0,
608 (AcceleratorOperation::Transform, AcceleratorType::DSP) => score += 2.0,
609 _ => {}
610 }
611
612 score
613 }
614}
615
// Unit tests for the accelerator types, the runtime and the optimizer.
// They only build with the standard library (see the `cfg` below).
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Builds a device with fixed vendor/version metadata so each test only
    /// spells out the values it cares about.
    fn test_device(
        id: u32,
        name: &str,
        model: &str,
        accelerator_type: AcceleratorType,
        capabilities: AcceleratorCapabilities,
    ) -> AcceleratorDevice {
        AcceleratorDevice {
            id,
            name: name.to_string(),
            vendor: "Test".to_string(),
            model: model.to_string(),
            accelerator_type,
            capabilities,
            driver_version: "1.0.0".to_string(),
            firmware_version: "1.0.0".to_string(),
            pci_id: None,
            numa_node: None,
        }
    }

    /// Kernel with a 512 x 512 x 1 work size, shared by the optimizer tests.
    fn test_kernel(work_size: (usize, usize, usize)) -> AcceleratorKernel {
        AcceleratorKernel {
            name: "test".to_string(),
            operation: AcceleratorOperation::MatrixMultiply,
            input_buffers: vec![0, 1],
            output_buffers: vec![2],
            parameters: HashMap::new(),
            work_size,
            local_size: (16, 16, 1),
        }
    }

    /// NPU supporting matrix multiply and VPU supporting convolution.
    fn selection_devices() -> Vec<AcceleratorDevice> {
        vec![
            test_device(
                0,
                "NPU",
                "NPU-100",
                AcceleratorType::NPU,
                AcceleratorCapabilities {
                    supported_operations: vec![AcceleratorOperation::MatrixMultiply],
                    data_types: vec![AcceleratorDataType::F32],
                    max_batch_size: 1024,
                    memory_mb: 4096,
                    compute_units: 32,
                    peak_performance_ops: 1_000_000_000,
                    power_consumption_w: 100.0,
                    precision_modes: vec![AcceleratorPrecision::High],
                },
            ),
            test_device(
                1,
                "VPU",
                "VPU-200",
                AcceleratorType::VPU,
                AcceleratorCapabilities {
                    supported_operations: vec![AcceleratorOperation::Convolution],
                    data_types: vec![AcceleratorDataType::F16],
                    max_batch_size: 512,
                    memory_mb: 2048,
                    compute_units: 16,
                    peak_performance_ops: 500_000_000,
                    power_consumption_w: 50.0,
                    precision_modes: vec![AcceleratorPrecision::Mixed],
                },
            ),
        ]
    }

    #[test]
    fn test_accelerator_runtime_creation() {
        let runtime = AcceleratorRuntime::new();
        assert!(runtime.is_ok());
    }

    #[test]
    fn test_runtime_has_no_devices_with_placeholder_drivers() {
        // The built-in drivers all report themselves unavailable.
        let runtime = AcceleratorRuntime::new().expect("runtime creation should succeed");
        assert!(runtime.devices().is_empty());
        assert!(runtime.devices_by_type(AcceleratorType::NPU).is_empty());
        assert!(runtime
            .get_best_device(AcceleratorOperation::MatrixMultiply)
            .is_none());
    }

    #[test]
    fn test_accelerator_type_display() {
        let expected = [
            (AcceleratorType::ASIC, "ASIC"),
            (AcceleratorType::Custom, "Custom"),
            (AcceleratorType::DSP, "DSP"),
            (AcceleratorType::VPU, "VPU"),
            (AcceleratorType::NPU, "NPU"),
        ];

        for (accel_type, name) in expected {
            assert_eq!(format!("{:?}", accel_type), name);
        }
    }

    #[test]
    fn test_accelerator_capabilities() {
        let caps = AcceleratorCapabilities {
            supported_operations: vec![
                AcceleratorOperation::MatrixMultiply,
                AcceleratorOperation::Convolution,
            ],
            data_types: vec![AcceleratorDataType::F32, AcceleratorDataType::F16],
            max_batch_size: 1024,
            memory_mb: 8192,
            compute_units: 64,
            peak_performance_ops: 1_000_000_000,
            power_consumption_w: 150.0,
            precision_modes: vec![AcceleratorPrecision::High, AcceleratorPrecision::Mixed],
        };

        assert_eq!(caps.supported_operations.len(), 2);
        assert_eq!(caps.data_types.len(), 2);
        assert_eq!(caps.max_batch_size, 1024);
    }

    #[test]
    fn test_accelerator_kernel() {
        let mut params = HashMap::new();
        params.insert("alpha".to_string(), AcceleratorParameter::Float(1.0));
        params.insert("beta".to_string(), AcceleratorParameter::Float(0.0));

        let kernel = AcceleratorKernel {
            name: "test_kernel".to_string(),
            operation: AcceleratorOperation::MatrixMultiply,
            input_buffers: vec![0, 1],
            output_buffers: vec![2],
            parameters: params,
            work_size: (1024, 1024, 1),
            local_size: (16, 16, 1),
        };

        assert_eq!(kernel.name, "test_kernel");
        assert_eq!(kernel.operation, AcceleratorOperation::MatrixMultiply);
        assert_eq!(kernel.input_buffers.len(), 2);
        assert_eq!(kernel.output_buffers.len(), 1);
        assert_eq!(kernel.parameters.len(), 2);
    }

    #[test]
    fn test_accelerator_device() {
        let device = AcceleratorDevice {
            id: 0,
            name: "Test Accelerator".to_string(),
            vendor: "Test Vendor".to_string(),
            model: "Test Model".to_string(),
            accelerator_type: AcceleratorType::NPU,
            capabilities: AcceleratorCapabilities {
                supported_operations: vec![AcceleratorOperation::MatrixMultiply],
                data_types: vec![AcceleratorDataType::F32],
                max_batch_size: 512,
                memory_mb: 4096,
                compute_units: 32,
                peak_performance_ops: 500_000_000,
                power_consumption_w: 100.0,
                precision_modes: vec![AcceleratorPrecision::High],
            },
            driver_version: "1.0.0".to_string(),
            firmware_version: "2.0.0".to_string(),
            pci_id: Some("1234:5678".to_string()),
            numa_node: Some(0),
        };

        assert_eq!(device.id, 0);
        assert_eq!(device.accelerator_type, AcceleratorType::NPU);
        assert_eq!(device.capabilities.compute_units, 32);
    }

    #[test]
    fn test_kernel_optimization() {
        let device = test_device(
            0,
            "Test NPU",
            "NPU-100",
            AcceleratorType::NPU,
            AcceleratorCapabilities {
                supported_operations: vec![AcceleratorOperation::MatrixMultiply],
                data_types: vec![AcceleratorDataType::F32],
                max_batch_size: 256,
                memory_mb: 2048,
                compute_units: 16,
                peak_performance_ops: 100_000_000,
                power_consumption_w: 50.0,
                precision_modes: vec![AcceleratorPrecision::High],
            },
        );
        let kernel = test_kernel((512, 512, 1));

        let optimized = optimization::optimize_kernel(&kernel, &device);
        assert!(optimized.is_ok());

        // The NPU strategy clamps only the first dimension to the batch limit.
        let opt_kernel = optimized.expect("operation should succeed");
        assert_eq!(opt_kernel.work_size, (256, 512, 1));
        assert_eq!(opt_kernel.local_size, kernel.local_size);
    }

    #[test]
    fn test_kernel_optimization_rounds_up_for_vpu() {
        let device = test_device(
            0,
            "Test VPU",
            "VPU-200",
            AcceleratorType::VPU,
            AcceleratorCapabilities {
                supported_operations: vec![AcceleratorOperation::Convolution],
                data_types: vec![AcceleratorDataType::F16],
                max_batch_size: 512,
                memory_mb: 2048,
                compute_units: 16,
                peak_performance_ops: 500_000_000,
                power_consumption_w: 50.0,
                precision_modes: vec![AcceleratorPrecision::Mixed],
            },
        );
        let kernel = test_kernel((100, 7, 1));

        let opt_kernel =
            optimization::optimize_kernel(&kernel, &device).expect("operation should succeed");

        // 100 rounded up to the next multiple of 16 compute units.
        assert_eq!(opt_kernel.work_size, (112, 7, 1));
    }

    #[test]
    fn test_accelerator_selection() {
        let devices = selection_devices();

        let selected =
            optimization::select_accelerator(AcceleratorOperation::MatrixMultiply, 512, &devices);
        assert!(selected.is_some());
        assert_eq!(selected.expect("operation should succeed").id, 0);

        let selected =
            optimization::select_accelerator(AcceleratorOperation::Convolution, 256, &devices);
        assert!(selected.is_some());
        assert_eq!(selected.expect("operation should succeed").id, 1);
    }

    #[test]
    fn test_accelerator_selection_rejects_unfit_requests() {
        let devices = selection_devices();

        // Larger than every device's batch limit.
        assert!(optimization::select_accelerator(
            AcceleratorOperation::MatrixMultiply,
            4096,
            &devices
        )
        .is_none());

        // Supported by no device.
        assert!(
            optimization::select_accelerator(AcceleratorOperation::Sort, 1, &devices).is_none()
        );
    }
}