Skip to main content

trustformers_core/hardware/
traits.rs

1// Copyright (c) 2025-2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! Hardware abstraction traits for TrustformeRS
5//!
6//! This module defines the core traits that hardware backends must implement
7//! to integrate with the TrustformeRS ecosystem. These traits provide a unified
8//! interface for tensor operations, memory management, and device control.
9
10use super::{HardwareCapabilities, HardwareConfig, HardwareMetrics, HardwareResult, HardwareType};
11use crate::tensor::Tensor;
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
16/// Core trait for hardware devices
17#[async_trait]
18pub trait HardwareDevice: Send + Sync {
19    /// Get device identifier
20    fn device_id(&self) -> &str;
21
22    /// Get hardware type
23    fn hardware_type(&self) -> HardwareType;
24
25    /// Get device capabilities
26    fn capabilities(&self) -> &HardwareCapabilities;
27
28    /// Initialize the device
29    async fn initialize(&mut self, config: &HardwareConfig) -> HardwareResult<()>;
30
31    /// Shutdown the device
32    async fn shutdown(&mut self) -> HardwareResult<()>;
33
34    /// Check if device is available
35    fn is_available(&self) -> bool;
36
37    /// Get current device status
38    fn status(&self) -> DeviceStatus;
39
40    /// Get current metrics
41    async fn metrics(&self) -> HardwareResult<HardwareMetrics>;
42
43    /// Reset device state
44    async fn reset(&mut self) -> HardwareResult<()>;
45
46    /// Allocate memory on device
47    async fn allocate_memory(&mut self, size: usize) -> HardwareResult<DeviceMemory>;
48
49    /// Free memory on device
50    async fn free_memory(&mut self, memory: DeviceMemory) -> HardwareResult<()>;
51
52    /// Synchronize device operations
53    async fn synchronize(&self) -> HardwareResult<()>;
54}
55
56/// Backend trait for hardware implementations
57#[async_trait]
58pub trait HardwareBackend: Send + Sync {
59    /// Backend name
60    fn name(&self) -> &str;
61
62    /// Backend version
63    fn version(&self) -> &str;
64
65    /// Discover available devices
66    async fn discover_devices(&self) -> HardwareResult<Vec<Box<dyn HardwareDevice>>>;
67
68    /// Create device from configuration
69    async fn create_device(
70        &self,
71        config: &HardwareConfig,
72    ) -> HardwareResult<Box<dyn HardwareDevice>>;
73
74    /// Check backend compatibility
75    fn is_compatible(&self, hardware_type: HardwareType) -> bool;
76
77    /// Get supported operations
78    fn supported_operations(&self) -> &[String];
79
80    /// Validate configuration
81    fn validate_config(&self, config: &HardwareConfig) -> HardwareResult<()>;
82}
83
84/// Hardware operation trait
85#[async_trait]
86pub trait HardwareOperation: Send + Sync {
87    /// Operation name
88    fn name(&self) -> &str;
89
90    /// Execute operation on device
91    async fn execute(
92        &self,
93        device: &mut dyn HardwareDevice,
94        inputs: &[Tensor],
95        outputs: &mut [Tensor],
96        params: &HashMap<String, OperationParameter>,
97    ) -> HardwareResult<()>;
98
99    /// Validate operation parameters
100    fn validate_params(&self, params: &HashMap<String, OperationParameter>) -> HardwareResult<()>;
101
102    /// Get operation requirements
103    fn requirements(&self) -> OperationRequirements;
104
105    /// Estimate operation cost
106    fn estimate_cost(&self, inputs: &[Tensor], params: &HashMap<String, OperationParameter>)
107        -> f64;
108}
109
110/// Device status information
111#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
112pub struct DeviceStatus {
113    /// Device is online and ready
114    pub online: bool,
115    /// Device is busy processing
116    pub busy: bool,
117    /// Error state
118    pub error: Option<String>,
119    /// Memory usage
120    pub memory_usage: MemoryUsage,
121    /// Temperature
122    pub temperature: Option<f64>,
123    /// Power consumption
124    pub power_consumption: Option<f64>,
125    /// Utilization percentage
126    pub utilization: f64,
127}
128
129/// Memory usage information
130#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
131pub struct MemoryUsage {
132    /// Total memory in bytes
133    pub total: usize,
134    /// Used memory in bytes
135    pub used: usize,
136    /// Free memory in bytes
137    pub free: usize,
138    /// Fragmentation ratio
139    pub fragmentation: f64,
140}
141
142/// Device memory handle
143#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
144pub struct DeviceMemory {
145    /// Memory address
146    pub address: usize,
147    /// Memory size in bytes
148    pub size: usize,
149    /// Memory type
150    pub memory_type: MemoryType,
151    /// Device identifier
152    pub device_id: String,
153}
154
155/// Memory types
156#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
157pub enum MemoryType {
158    /// Device local memory
159    Local,
160    /// Host memory
161    Host,
162    /// Shared memory
163    Shared,
164    /// Unified memory
165    Unified,
166    /// Persistent memory
167    Persistent,
168    /// Cache memory
169    Cache,
170}
171
172/// Operation parameter types
173#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
174pub enum OperationParameter {
175    /// Integer parameter
176    Integer(i64),
177    /// Float parameter
178    Float(f64),
179    /// String parameter
180    String(String),
181    /// Boolean parameter
182    Boolean(bool),
183    /// Array parameter
184    Array(Vec<OperationParameter>),
185    /// Object parameter
186    Object(HashMap<String, OperationParameter>),
187}
188
189/// Operation requirements
190#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
191pub struct OperationRequirements {
192    /// Minimum memory required
193    pub min_memory: usize,
194    /// Required compute units
195    pub compute_units: Option<u32>,
196    /// Required data types
197    pub data_types: Vec<super::DataType>,
198    /// Required capabilities
199    pub capabilities: Vec<String>,
200    /// Performance characteristics
201    pub performance: PerformanceRequirements,
202}
203
204/// Performance requirements
205#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
206pub struct PerformanceRequirements {
207    /// Maximum latency in milliseconds
208    pub max_latency: Option<f64>,
209    /// Minimum throughput in operations per second
210    pub min_throughput: Option<f64>,
211    /// Memory bandwidth requirements
212    pub memory_bandwidth: Option<f64>,
213    /// Power consumption limit
214    pub power_limit: Option<f64>,
215}
216
217/// Async hardware operation trait
218#[async_trait]
219pub trait AsyncHardwareOperation: Send + Sync {
220    /// Start async operation
221    async fn start(
222        &self,
223        device: &mut dyn HardwareDevice,
224        inputs: &[Tensor],
225        params: &HashMap<String, OperationParameter>,
226    ) -> HardwareResult<AsyncOperationHandle>;
227
228    /// Check operation status
229    async fn status(&self, handle: &AsyncOperationHandle) -> HardwareResult<AsyncOperationStatus>;
230
231    /// Get operation results
232    async fn results(&self, handle: &AsyncOperationHandle) -> HardwareResult<Vec<Tensor>>;
233
234    /// Cancel operation
235    async fn cancel(&self, handle: &AsyncOperationHandle) -> HardwareResult<()>;
236}
237
238/// Async operation handle
239#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
240pub struct AsyncOperationHandle {
241    /// Operation identifier
242    pub id: String,
243    /// Device identifier
244    pub device_id: String,
245    /// Operation name
246    pub operation_name: String,
247    /// Start time
248    pub start_time: std::time::SystemTime,
249}
250
251/// Async operation status
252#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
253pub enum AsyncOperationStatus {
254    /// Operation is queued
255    Queued,
256    /// Operation is running
257    Running,
258    /// Operation completed successfully
259    Completed,
260    /// Operation failed
261    Failed(String),
262    /// Operation was cancelled
263    Cancelled,
264}
265
266/// Hardware scheduler trait
267pub trait HardwareScheduler: Send + Sync + std::fmt::Debug {
268    /// Schedule operation on best available device
269    fn schedule_operation(
270        &self,
271        operation: &dyn HardwareOperation,
272        inputs: &[Tensor],
273        params: &HashMap<String, OperationParameter>,
274    ) -> HardwareResult<String>; // Returns device_id
275
276    /// Get scheduler statistics
277    fn statistics(&self) -> SchedulerStatistics;
278
279    /// Update device priorities
280    fn update_priorities(&mut self, priorities: HashMap<String, f64>);
281}
282
283/// Scheduler statistics
284#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
285pub struct SchedulerStatistics {
286    /// Total operations scheduled
287    pub total_operations: u64,
288    /// Operations per device
289    pub operations_per_device: HashMap<String, u64>,
290    /// Average scheduling time
291    pub avg_scheduling_time: f64,
292    /// Device utilization
293    pub device_utilization: HashMap<String, f64>,
294    /// Failed operations
295    pub failed_operations: u64,
296}
297
298impl Default for DeviceStatus {
299    fn default() -> Self {
300        Self {
301            online: false,
302            busy: false,
303            error: None,
304            memory_usage: MemoryUsage::default(),
305            temperature: None,
306            power_consumption: None,
307            utilization: 0.0,
308        }
309    }
310}
311
312impl Default for MemoryUsage {
313    fn default() -> Self {
314        Self {
315            total: 0,
316            used: 0,
317            free: 0,
318            fragmentation: 0.0,
319        }
320    }
321}
322
323impl Default for OperationRequirements {
324    fn default() -> Self {
325        Self {
326            min_memory: 0,
327            compute_units: None,
328            data_types: vec![super::DataType::F32],
329            capabilities: vec![],
330            performance: PerformanceRequirements::default(),
331        }
332    }
333}
334
335impl Default for SchedulerStatistics {
336    fn default() -> Self {
337        Self {
338            total_operations: 0,
339            operations_per_device: HashMap::new(),
340            avg_scheduling_time: 0.0,
341            device_utilization: HashMap::new(),
342            failed_operations: 0,
343        }
344    }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `DeviceStatus` is offline, idle, error-free and unused.
    #[test]
    fn test_device_status_default() {
        let status = DeviceStatus::default();
        assert!(!status.online);
        assert!(!status.busy);
        assert!(status.error.is_none());
        assert!(status.temperature.is_none());
        assert!(status.power_consumption.is_none());
        assert_eq!(status.memory_usage, MemoryUsage::default());
        assert_eq!(status.utilization, 0.0);
    }

    /// `free` derived from `total - used` has the expected value.
    #[test]
    fn test_memory_usage_calculation() {
        let (total, used) = (1000, 600);
        // Struct-update syntax instead of reassigning fields on a default
        // value (Clippy `field_reassign_with_default`).
        let usage = MemoryUsage {
            total,
            used,
            free: total - used,
            ..Default::default()
        };
        assert_eq!(usage.total, 1000);
        assert_eq!(usage.used, 600);
        assert_eq!(usage.free, 400);
    }

    /// Each scalar `OperationParameter` variant carries the value it was
    /// built with.
    #[test]
    fn test_operation_parameter_types() {
        let int_param = OperationParameter::Integer(42);
        let float_param = OperationParameter::Float(std::f64::consts::PI);
        let string_param = OperationParameter::String("test".to_string());
        let bool_param = OperationParameter::Boolean(true);

        match int_param {
            OperationParameter::Integer(val) => assert_eq!(val, 42),
            _ => panic!("Expected Integer parameter but got {:?}", int_param),
        }

        match float_param {
            OperationParameter::Float(val) => assert_eq!(val, std::f64::consts::PI),
            _ => panic!("Expected Float parameter but got {:?}", float_param),
        }

        // These two were previously constructed but never checked.
        assert!(matches!(&string_param, OperationParameter::String(s) if s == "test"));
        assert!(matches!(bool_param, OperationParameter::Boolean(true)));
    }

    /// Status values compare by variant, and `Failed` keeps its message.
    #[test]
    fn test_async_operation_status() {
        let status = AsyncOperationStatus::Queued;
        assert_eq!(status, AsyncOperationStatus::Queued);

        let failed_status = AsyncOperationStatus::Failed("test error".to_string());
        // Match by reference so the value stays intact for the panic message.
        match &failed_status {
            AsyncOperationStatus::Failed(msg) => assert_eq!(msg, "test error"),
            other => panic!("Expected Failed status but got {:?}", other),
        }
    }

    /// `MemoryType` variants are equal to themselves and distinct from others.
    #[test]
    fn test_memory_type_equality() {
        assert_eq!(MemoryType::Local, MemoryType::Local);
        assert_ne!(MemoryType::Local, MemoryType::Host);
        assert_eq!(MemoryType::Shared, MemoryType::Shared);
    }
}