Skip to main content

trustformers_core/hardware/
mod.rs

1// Copyright (c) 2025-2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Hardware acceleration abstraction layer for TrustformeRS
5//!
6//! This module provides a unified interface for various hardware acceleration
7//! platforms, including custom ASICs, neuromorphic processors, and specialized
8//! AI accelerators. It enables seamless integration of new hardware backends
9//! while maintaining compatibility with existing tensor operations.
10
11pub mod allocation;
12pub mod asic;
13pub mod backends;
14pub mod config;
15pub mod devices;
16pub mod manager;
17pub mod monitoring;
18pub mod registry;
19pub mod scheduling;
20pub mod traits;
21
22pub use allocation::{LoadBalancer, MemoryManager, ResourceAllocator};
23pub use asic::{AsicBackend, AsicDevice, AsicOperationSet};
24pub use backends::{CPUBackend, CPUBackendConfig, GPUBackend, GPUBackendConfig};
25pub use config::{AllocationStrategy, LoadBalancingStrategy};
26pub use config::{DeviceInfo, HardwareManagerConfig};
27pub use devices::{CPUDevice, GPUBackendType, GPUDevice};
28pub use manager::HardwareManager;
29pub use monitoring::{
30    AnomalyDetector, AnomalySeverity, AnomalyType, HealthChecker, HealthStatus, PerformanceMonitor,
31};
32pub use registry::HardwareRegistry;
33pub use scheduling::{AdvancedScheduler, DefaultScheduler, SchedulingAlgorithm};
34
35use crate::errors::TrustformersError;
36use serde::{Deserialize, Serialize};
37pub use traits::{
38    HardwareBackend, HardwareDevice, HardwareOperation, HardwareScheduler, OperationParameter,
39    SchedulerStatistics,
40};
41
42/// Supported hardware types in TrustformeRS
43#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
44pub enum HardwareType {
45    /// Central Processing Unit
46    CPU,
47    /// Graphics Processing Unit (CUDA, ROCm, Metal, etc.)
48    GPU,
49    /// Custom Application-Specific Integrated Circuit
50    ASIC,
51    /// Neuromorphic processing unit
52    Neuromorphic,
53    /// Quantum processing unit
54    Quantum,
55    /// Field-Programmable Gate Array
56    FPGA,
57    /// Digital Signal Processor
58    DSP,
59    /// Tensor Processing Unit
60    TPU,
61    /// Vision Processing Unit
62    VPU,
63    /// Custom accelerator
64    Custom(String),
65}
66
67/// Hardware capability flags
68#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
69pub struct HardwareCapabilities {
70    /// Supported data types
71    pub data_types: Vec<DataType>,
72    /// Maximum tensor dimensions
73    pub max_dimensions: usize,
74    /// Memory size in bytes
75    pub memory_size: Option<usize>,
76    /// Clock frequency in Hz
77    pub clock_frequency: Option<u64>,
78    /// Compute units
79    pub compute_units: Option<u32>,
80    /// Supported operations
81    pub operations: Vec<String>,
82    /// Power consumption in watts
83    pub power_consumption: Option<f64>,
84    /// Thermal design power
85    pub thermal_design_power: Option<f64>,
86}
87
88/// Supported data types for hardware operations
89#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
90#[repr(C)]
91pub enum DataType {
92    F32,
93    F16,
94    BF16,
95    F64,
96    I8,
97    I16,
98    I32,
99    I64,
100    U8,
101    U16,
102    U32,
103    U64,
104    Bool,
105    Complex64,
106    Complex128,
107    Custom(u8), // Custom bit width
108}
109
110/// Hardware performance metrics
111#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
112pub struct HardwareMetrics {
113    /// Operations per second
114    pub ops_per_second: f64,
115    /// Memory bandwidth in bytes/second
116    pub memory_bandwidth: f64,
117    /// Utilization percentage (0.0 to 1.0)
118    pub utilization: f64,
119    /// Power consumption in watts
120    pub power_consumption: f64,
121    /// Temperature in Celsius
122    pub temperature: Option<f64>,
123    /// Error rate
124    pub error_rate: f64,
125    /// Latency in microseconds
126    pub latency: f64,
127    /// Throughput in operations per second
128    pub throughput: f64,
129}
130
131/// Hardware configuration for different operation modes
132#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
133pub struct HardwareConfig {
134    /// Hardware type
135    pub hardware_type: HardwareType,
136    /// Device identifier
137    pub device_id: String,
138    /// Operation mode (Performance, Efficiency, Balanced)
139    pub operation_mode: OperationMode,
140    /// Memory pool size
141    pub memory_pool_size: Option<usize>,
142    /// Batch size limits
143    pub batch_size_limits: Option<(usize, usize)>,
144    /// Precision mode
145    pub precision_mode: PrecisionMode,
146    /// Custom parameters
147    pub custom_params: std::collections::HashMap<String, String>,
148}
149
150/// Operation modes for hardware optimization
151#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
152pub enum OperationMode {
153    /// Maximum performance
154    Performance,
155    /// Maximum efficiency
156    Efficiency,
157    /// Balanced performance and efficiency
158    Balanced,
159    /// Low power consumption
160    LowPower,
161    /// High precision
162    HighPrecision,
163    /// Custom mode
164    Custom,
165}
166
167/// Precision modes for hardware operations
168#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
169pub enum PrecisionMode {
170    /// Single precision floating point
171    Single,
172    /// Half precision floating point
173    Half,
174    /// Brain floating point
175    BFloat16,
176    /// Double precision floating point
177    Double,
178    /// Mixed precision
179    Mixed,
180    /// Integer precision
181    Integer(u8),
182    /// Custom precision
183    Custom(u8),
184}
185
186impl Default for HardwareCapabilities {
187    fn default() -> Self {
188        Self {
189            data_types: vec![DataType::F32],
190            max_dimensions: 8,
191            memory_size: None,
192            clock_frequency: None,
193            compute_units: None,
194            operations: vec![],
195            power_consumption: None,
196            thermal_design_power: None,
197        }
198    }
199}
200
201impl Default for HardwareConfig {
202    fn default() -> Self {
203        Self {
204            hardware_type: HardwareType::CPU,
205            device_id: "default".to_string(),
206            operation_mode: OperationMode::Balanced,
207            memory_pool_size: None,
208            batch_size_limits: None,
209            precision_mode: PrecisionMode::Single,
210            custom_params: std::collections::HashMap::new(),
211        }
212    }
213}
214
215impl std::fmt::Display for HardwareType {
216    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
217        match self {
218            HardwareType::CPU => write!(f, "CPU"),
219            HardwareType::GPU => write!(f, "GPU"),
220            HardwareType::ASIC => write!(f, "ASIC"),
221            HardwareType::Neuromorphic => write!(f, "Neuromorphic"),
222            HardwareType::Quantum => write!(f, "Quantum"),
223            HardwareType::FPGA => write!(f, "FPGA"),
224            HardwareType::DSP => write!(f, "DSP"),
225            HardwareType::TPU => write!(f, "TPU"),
226            HardwareType::VPU => write!(f, "VPU"),
227            HardwareType::Custom(name) => write!(f, "Custom({})", name),
228        }
229    }
230}
231
232impl std::fmt::Display for DataType {
233    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234        match self {
235            DataType::F32 => write!(f, "f32"),
236            DataType::F16 => write!(f, "f16"),
237            DataType::BF16 => write!(f, "bf16"),
238            DataType::F64 => write!(f, "f64"),
239            DataType::I8 => write!(f, "i8"),
240            DataType::I16 => write!(f, "i16"),
241            DataType::I32 => write!(f, "i32"),
242            DataType::I64 => write!(f, "i64"),
243            DataType::U8 => write!(f, "u8"),
244            DataType::U16 => write!(f, "u16"),
245            DataType::U32 => write!(f, "u32"),
246            DataType::U64 => write!(f, "u64"),
247            DataType::Bool => write!(f, "bool"),
248            DataType::Complex64 => write!(f, "complex64"),
249            DataType::Complex128 => write!(f, "complex128"),
250            DataType::Custom(bits) => write!(f, "custom({})", bits),
251        }
252    }
253}
254
255/// Hardware abstraction result type
256pub type HardwareResult<T> = Result<T, TrustformersError>;
257
#[cfg(test)]
mod tests {
    use super::asic::*;
    use super::traits::DeviceStatus as TraitsDeviceStatus;
    use super::*;

    use std::collections::HashMap;

    /// Every built-in `HardwareType` renders as its variant name, and
    /// `Custom` wraps the user-supplied name.
    #[test]
    fn test_hardware_type_display() {
        let expected = [
            (HardwareType::CPU, "CPU"),
            (HardwareType::GPU, "GPU"),
            (HardwareType::ASIC, "ASIC"),
            (HardwareType::Neuromorphic, "Neuromorphic"),
            (HardwareType::Quantum, "Quantum"),
            (HardwareType::FPGA, "FPGA"),
            (HardwareType::DSP, "DSP"),
            (HardwareType::TPU, "TPU"),
            (HardwareType::VPU, "VPU"),
        ];
        for (hardware_type, name) in expected {
            assert_eq!(hardware_type.to_string(), name);
        }
        // A custom name that collides with a built-in variant still renders
        // through the `Custom(...)` wrapper.
        assert_eq!(
            HardwareType::Custom("TPU".to_string()).to_string(),
            "Custom(TPU)"
        );
    }

    /// Every `DataType` renders as its lowercase name; `Custom` shows its width.
    #[test]
    fn test_data_type_display() {
        let expected = [
            (DataType::F32, "f32"),
            (DataType::F16, "f16"),
            (DataType::BF16, "bf16"),
            (DataType::F64, "f64"),
            (DataType::I8, "i8"),
            (DataType::I16, "i16"),
            (DataType::I32, "i32"),
            (DataType::I64, "i64"),
            (DataType::U8, "u8"),
            (DataType::U16, "u16"),
            (DataType::U32, "u32"),
            (DataType::U64, "u64"),
            (DataType::Bool, "bool"),
            (DataType::Complex64, "complex64"),
            (DataType::Complex128, "complex128"),
        ];
        for (data_type, name) in expected {
            assert_eq!(data_type.to_string(), name);
        }
        assert_eq!(DataType::Custom(8).to_string(), "custom(8)");
    }

    #[test]
    fn test_hardware_capabilities_default() {
        let caps = HardwareCapabilities::default();
        assert_eq!(caps.data_types, vec![DataType::F32]);
        assert_eq!(caps.max_dimensions, 8);
        assert!(caps.memory_size.is_none());
    }

    #[test]
    fn test_hardware_config_default() {
        let config = HardwareConfig::default();
        assert_eq!(config.hardware_type, HardwareType::CPU);
        assert_eq!(config.device_id, "default");
        assert_eq!(config.operation_mode, OperationMode::Balanced);
        assert_eq!(config.precision_mode, PrecisionMode::Single);
    }

    #[test]
    fn test_hardware_types_equality() {
        assert_eq!(HardwareType::CPU, HardwareType::CPU);
        assert_ne!(HardwareType::CPU, HardwareType::GPU);
        assert_eq!(
            HardwareType::Custom("TPU".to_string()),
            HardwareType::Custom("TPU".to_string())
        );
    }

    #[test]
    fn test_asic_type_varieties() {
        let asic_types = [
            AsicType::AIInference,
            AsicType::NPU,
            AsicType::TPU,
            AsicType::DSP,
            AsicType::VPU,
            AsicType::Crypto,
            AsicType::EdgeAI,
            AsicType::Custom("CustomAccelerator".to_string()),
        ];

        assert_eq!(asic_types.len(), 8);
        assert_eq!(asic_types[0], AsicType::AIInference);
        assert_eq!(
            asic_types[7],
            AsicType::Custom("CustomAccelerator".to_string())
        );
    }

    #[test]
    fn test_asic_vendor_creation() {
        let vendor = AsicVendor {
            name: "TrustformeRS Chips".to_string(),
            id: 0x1234,
            driver_version: "2.1.0".to_string(),
            firmware_version: "1.5.2".to_string(),
            support_contact: Some("support@trustformers.ai".to_string()),
        };

        assert_eq!(vendor.name, "TrustformeRS Chips");
        assert_eq!(vendor.id, 0x1234);
        assert!(vendor.support_contact.is_some());
    }

    #[test]
    fn test_device_status_and_memory_usage() {
        use super::traits::MemoryUsage;

        let memory_usage = MemoryUsage {
            total: 8192,
            used: 4096,
            free: 4096,
            fragmentation: 0.1,
        };

        let status = TraitsDeviceStatus {
            online: true,
            busy: false,
            error: None,
            memory_usage,
            temperature: Some(70.5),
            power_consumption: Some(150.0),
            utilization: 0.8,
        };

        assert!(status.online);
        assert!(!status.busy);
        assert!(status.error.is_none());
        assert_eq!(status.memory_usage.total, 8192);
        assert_eq!(status.memory_usage.used, 4096);
        assert_eq!(status.memory_usage.free, 4096);
        assert_eq!(status.temperature, Some(70.5));
        assert_eq!(status.utilization, 0.8);
    }

    #[test]
    fn test_operation_parameters() {
        use super::traits::OperationParameter;

        let mut params = HashMap::new();
        params.insert(
            "learning_rate".to_string(),
            OperationParameter::Float(0.001),
        );
        params.insert("batch_size".to_string(), OperationParameter::Integer(32));
        params.insert(
            "model_name".to_string(),
            OperationParameter::String("bert-base".to_string()),
        );
        params.insert("use_fp16".to_string(), OperationParameter::Boolean(true));

        let array_param = OperationParameter::Array(vec![
            OperationParameter::Integer(1),
            OperationParameter::Integer(2),
            OperationParameter::Integer(3),
        ]);
        params.insert("dimensions".to_string(), array_param);

        assert_eq!(params.len(), 5);

        // Report the value that actually failed to match, rather than
        // re-querying the map (which printed an `Option` wrapper).
        match params.get("learning_rate").expect("expected value not found") {
            OperationParameter::Float(val) => assert_eq!(*val, 0.001),
            other => panic!("Expected Float parameter but got {:?}", other),
        }

        match params.get("batch_size").expect("expected value not found") {
            OperationParameter::Integer(val) => assert_eq!(*val, 32),
            other => panic!("Expected Integer parameter but got {:?}", other),
        }
    }

    #[test]
    fn test_memory_types() {
        use super::traits::{DeviceMemory, MemoryType};

        let memory_types = [
            MemoryType::Local,
            MemoryType::Host,
            MemoryType::Shared,
            MemoryType::Unified,
            MemoryType::Persistent,
            MemoryType::Cache,
        ];

        assert_eq!(memory_types.len(), 6);
        assert_eq!(memory_types[0], MemoryType::Local);
        assert_ne!(memory_types[0], MemoryType::Host);

        let device_memory = DeviceMemory {
            address: 0x10000000,
            size: 1024 * 1024, // 1MB
            memory_type: MemoryType::Local,
            device_id: "gpu_0".to_string(),
        };

        assert_eq!(device_memory.address, 0x10000000);
        assert_eq!(device_memory.size, 1024 * 1024);
        assert_eq!(device_memory.memory_type, MemoryType::Local);
        assert_eq!(device_memory.device_id, "gpu_0");
    }

    #[test]
    fn test_hardware_metrics() {
        let metrics = HardwareMetrics {
            ops_per_second: 1000.0,
            memory_bandwidth: 500.0,
            utilization: 0.5,
            power_consumption: 100.0,
            temperature: Some(65.0),
            error_rate: 0.001,
            latency: 10.0,
            throughput: 1000.0,
        };

        assert_eq!(metrics.ops_per_second, 1000.0);
        assert_eq!(metrics.utilization, 0.5);
        assert_eq!(metrics.temperature, Some(65.0));
        assert!(metrics.error_rate < 0.01);
    }

    #[test]
    fn test_precision_modes() {
        let precision_modes = [
            PrecisionMode::Single,
            PrecisionMode::Half,
            PrecisionMode::BFloat16,
            PrecisionMode::Double,
            PrecisionMode::Mixed,
            PrecisionMode::Integer(8),
            PrecisionMode::Custom(12),
        ];

        assert_eq!(precision_modes.len(), 7);
        assert_eq!(precision_modes[0], PrecisionMode::Single);
        assert_eq!(precision_modes[5], PrecisionMode::Integer(8));
        assert_eq!(precision_modes[6], PrecisionMode::Custom(12));
    }

    #[test]
    fn test_operation_modes() {
        let operation_modes = [
            OperationMode::Performance,
            OperationMode::Efficiency,
            OperationMode::Balanced,
            OperationMode::LowPower,
            OperationMode::HighPrecision,
            OperationMode::Custom,
        ];

        assert_eq!(operation_modes.len(), 6);
        assert_eq!(operation_modes[0], OperationMode::Performance);
        assert_eq!(operation_modes[2], OperationMode::Balanced);
        assert_eq!(operation_modes[5], OperationMode::Custom);
    }

    #[test]
    fn test_hardware_serialization() {
        // Test HardwareType serialization
        let hardware_type = HardwareType::Custom("TestAccelerator".to_string());
        let serialized = serde_json::to_string(&hardware_type).expect("JSON serialization failed");
        let deserialized: HardwareType =
            serde_json::from_str(&serialized).expect("JSON deserialization failed");
        assert_eq!(hardware_type, deserialized);

        // Test DataType serialization
        let data_type = DataType::Custom(12);
        let serialized = serde_json::to_string(&data_type).expect("JSON serialization failed");
        let deserialized: DataType =
            serde_json::from_str(&serialized).expect("JSON deserialization failed");
        assert_eq!(data_type, deserialized);

        // Test OperationMode serialization
        let operation_mode = OperationMode::Performance;
        let serialized = serde_json::to_string(&operation_mode).expect("JSON serialization failed");
        let deserialized: OperationMode =
            serde_json::from_str(&serialized).expect("JSON deserialization failed");
        assert_eq!(operation_mode, deserialized);
    }

    #[test]
    fn test_hardware_capabilities_custom() {
        let caps = HardwareCapabilities {
            data_types: vec![DataType::F32, DataType::F16, DataType::I8],
            max_dimensions: 16,
            memory_size: Some(8 * 1024 * 1024 * 1024), // 8GB
            clock_frequency: Some(2_500_000_000),      // 2.5 GHz
            compute_units: Some(64),
            operations: vec![
                "matmul".to_string(),
                "conv2d".to_string(),
                "attention".to_string(),
            ],
            power_consumption: Some(250.0),
            thermal_design_power: Some(300.0),
        };

        assert_eq!(caps.data_types.len(), 3);
        assert_eq!(caps.max_dimensions, 16);
        assert_eq!(caps.memory_size, Some(8 * 1024 * 1024 * 1024));
        assert_eq!(caps.operations.len(), 3);
        assert!(caps.operations.contains(&"matmul".to_string()));
    }

    #[test]
    fn test_hardware_config_custom() {
        let mut custom_params = HashMap::new();
        custom_params.insert("vendor".to_string(), "TrustformeRS".to_string());
        custom_params.insert("model".to_string(), "TF-1000".to_string());
        custom_params.insert("revision".to_string(), "A1".to_string());

        let config = HardwareConfig {
            hardware_type: HardwareType::ASIC,
            device_id: "asic_0".to_string(),
            operation_mode: OperationMode::Performance,
            memory_pool_size: Some(1024 * 1024 * 1024), // 1GB
            batch_size_limits: Some((1, 256)),
            precision_mode: PrecisionMode::Mixed,
            custom_params,
        };

        assert_eq!(config.hardware_type, HardwareType::ASIC);
        assert_eq!(config.device_id, "asic_0");
        assert_eq!(config.operation_mode, OperationMode::Performance);
        assert_eq!(config.memory_pool_size, Some(1024 * 1024 * 1024));
        assert_eq!(config.batch_size_limits, Some((1, 256)));
        assert_eq!(config.precision_mode, PrecisionMode::Mixed);
        assert_eq!(config.custom_params.len(), 3);
        assert_eq!(
            config.custom_params.get("vendor"),
            Some(&"TrustformeRS".to_string())
        );
    }
}