1use super::{HardwareCapabilities, HardwareConfig, HardwareMetrics, HardwareResult, HardwareType};
11use crate::tensor::Tensor;
12use async_trait::async_trait;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15
16#[async_trait]
18pub trait HardwareDevice: Send + Sync {
19 fn device_id(&self) -> &str;
21
22 fn hardware_type(&self) -> HardwareType;
24
25 fn capabilities(&self) -> &HardwareCapabilities;
27
28 async fn initialize(&mut self, config: &HardwareConfig) -> HardwareResult<()>;
30
31 async fn shutdown(&mut self) -> HardwareResult<()>;
33
34 fn is_available(&self) -> bool;
36
37 fn status(&self) -> DeviceStatus;
39
40 async fn metrics(&self) -> HardwareResult<HardwareMetrics>;
42
43 async fn reset(&mut self) -> HardwareResult<()>;
45
46 async fn allocate_memory(&mut self, size: usize) -> HardwareResult<DeviceMemory>;
48
49 async fn free_memory(&mut self, memory: DeviceMemory) -> HardwareResult<()>;
51
52 async fn synchronize(&self) -> HardwareResult<()>;
54}
55
56#[async_trait]
58pub trait HardwareBackend: Send + Sync {
59 fn name(&self) -> &str;
61
62 fn version(&self) -> &str;
64
65 async fn discover_devices(&self) -> HardwareResult<Vec<Box<dyn HardwareDevice>>>;
67
68 async fn create_device(
70 &self,
71 config: &HardwareConfig,
72 ) -> HardwareResult<Box<dyn HardwareDevice>>;
73
74 fn is_compatible(&self, hardware_type: HardwareType) -> bool;
76
77 fn supported_operations(&self) -> &[String];
79
80 fn validate_config(&self, config: &HardwareConfig) -> HardwareResult<()>;
82}
83
84#[async_trait]
86pub trait HardwareOperation: Send + Sync {
87 fn name(&self) -> &str;
89
90 async fn execute(
92 &self,
93 device: &mut dyn HardwareDevice,
94 inputs: &[Tensor],
95 outputs: &mut [Tensor],
96 params: &HashMap<String, OperationParameter>,
97 ) -> HardwareResult<()>;
98
99 fn validate_params(&self, params: &HashMap<String, OperationParameter>) -> HardwareResult<()>;
101
102 fn requirements(&self) -> OperationRequirements;
104
105 fn estimate_cost(&self, inputs: &[Tensor], params: &HashMap<String, OperationParameter>)
107 -> f64;
108}
109
110#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
112pub struct DeviceStatus {
113 pub online: bool,
115 pub busy: bool,
117 pub error: Option<String>,
119 pub memory_usage: MemoryUsage,
121 pub temperature: Option<f64>,
123 pub power_consumption: Option<f64>,
125 pub utilization: f64,
127}
128
129#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
131pub struct MemoryUsage {
132 pub total: usize,
134 pub used: usize,
136 pub free: usize,
138 pub fragmentation: f64,
140}
141
142#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
144pub struct DeviceMemory {
145 pub address: usize,
147 pub size: usize,
149 pub memory_type: MemoryType,
151 pub device_id: String,
153}
154
155#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
157pub enum MemoryType {
158 Local,
160 Host,
162 Shared,
164 Unified,
166 Persistent,
168 Cache,
170}
171
172#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
174pub enum OperationParameter {
175 Integer(i64),
177 Float(f64),
179 String(String),
181 Boolean(bool),
183 Array(Vec<OperationParameter>),
185 Object(HashMap<String, OperationParameter>),
187}
188
189#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
191pub struct OperationRequirements {
192 pub min_memory: usize,
194 pub compute_units: Option<u32>,
196 pub data_types: Vec<super::DataType>,
198 pub capabilities: Vec<String>,
200 pub performance: PerformanceRequirements,
202}
203
204#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
206pub struct PerformanceRequirements {
207 pub max_latency: Option<f64>,
209 pub min_throughput: Option<f64>,
211 pub memory_bandwidth: Option<f64>,
213 pub power_limit: Option<f64>,
215}
216
217#[async_trait]
219pub trait AsyncHardwareOperation: Send + Sync {
220 async fn start(
222 &self,
223 device: &mut dyn HardwareDevice,
224 inputs: &[Tensor],
225 params: &HashMap<String, OperationParameter>,
226 ) -> HardwareResult<AsyncOperationHandle>;
227
228 async fn status(&self, handle: &AsyncOperationHandle) -> HardwareResult<AsyncOperationStatus>;
230
231 async fn results(&self, handle: &AsyncOperationHandle) -> HardwareResult<Vec<Tensor>>;
233
234 async fn cancel(&self, handle: &AsyncOperationHandle) -> HardwareResult<()>;
236}
237
238#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
240pub struct AsyncOperationHandle {
241 pub id: String,
243 pub device_id: String,
245 pub operation_name: String,
247 pub start_time: std::time::SystemTime,
249}
250
251#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
253pub enum AsyncOperationStatus {
254 Queued,
256 Running,
258 Completed,
260 Failed(String),
262 Cancelled,
264}
265
266pub trait HardwareScheduler: Send + Sync + std::fmt::Debug {
268 fn schedule_operation(
270 &self,
271 operation: &dyn HardwareOperation,
272 inputs: &[Tensor],
273 params: &HashMap<String, OperationParameter>,
274 ) -> HardwareResult<String>; fn statistics(&self) -> SchedulerStatistics;
278
279 fn update_priorities(&mut self, priorities: HashMap<String, f64>);
281}
282
283#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
285pub struct SchedulerStatistics {
286 pub total_operations: u64,
288 pub operations_per_device: HashMap<String, u64>,
290 pub avg_scheduling_time: f64,
292 pub device_utilization: HashMap<String, f64>,
294 pub failed_operations: u64,
296}
297
298impl Default for DeviceStatus {
299 fn default() -> Self {
300 Self {
301 online: false,
302 busy: false,
303 error: None,
304 memory_usage: MemoryUsage::default(),
305 temperature: None,
306 power_consumption: None,
307 utilization: 0.0,
308 }
309 }
310}
311
312impl Default for MemoryUsage {
313 fn default() -> Self {
314 Self {
315 total: 0,
316 used: 0,
317 free: 0,
318 fragmentation: 0.0,
319 }
320 }
321}
322
323impl Default for OperationRequirements {
324 fn default() -> Self {
325 Self {
326 min_memory: 0,
327 compute_units: None,
328 data_types: vec![super::DataType::F32],
329 capabilities: vec![],
330 performance: PerformanceRequirements::default(),
331 }
332 }
333}
334
335impl Default for SchedulerStatistics {
336 fn default() -> Self {
337 Self {
338 total_operations: 0,
339 operations_per_device: HashMap::new(),
340 avg_scheduling_time: 0.0,
341 device_utilization: HashMap::new(),
342 failed_operations: 0,
343 }
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `DeviceStatus` is offline, not busy, error-free and has zero utilization.
    #[test]
    fn test_device_status_default() {
        let status = DeviceStatus::default();
        assert!(!status.online);
        assert!(!status.busy);
        assert!(status.error.is_none());
        assert_eq!(status.utilization, 0.0);
    }

    /// `free` computed as `total - used` yields the expected remainder.
    #[test]
    fn test_memory_usage_calculation() {
        // Struct-update syntax instead of mutating a default value field by
        // field (clippy::field_reassign_with_default).
        let (total, used) = (1000, 600);
        let usage = MemoryUsage {
            total,
            used,
            free: total - used,
            ..MemoryUsage::default()
        };
        assert_eq!(usage.free, 400);
    }

    /// Each scalar `OperationParameter` variant can be built and matched back out.
    #[test]
    fn test_operation_parameter_types() {
        let int_param = OperationParameter::Integer(42);
        let float_param = OperationParameter::Float(std::f64::consts::PI);
        let _string_param = OperationParameter::String("test".to_string());
        let _bool_param = OperationParameter::Boolean(true);

        // Bind the unexpected value in the fallback arm rather than reaching
        // back to the scrutinee.
        match &int_param {
            OperationParameter::Integer(val) => assert_eq!(*val, 42),
            other => panic!("Expected Integer parameter but got {:?}", other),
        }

        match &float_param {
            OperationParameter::Float(val) => assert_eq!(*val, std::f64::consts::PI),
            other => panic!("Expected Float parameter but got {:?}", other),
        }
    }

    /// `AsyncOperationStatus` compares by variant and carries its failure message.
    #[test]
    fn test_async_operation_status() {
        let status = AsyncOperationStatus::Queued;
        assert_eq!(status, AsyncOperationStatus::Queued);

        let failed_status = AsyncOperationStatus::Failed("test error".to_string());
        match &failed_status {
            AsyncOperationStatus::Failed(msg) => assert_eq!(msg.as_str(), "test error"),
            other => panic!("Expected Failed status but got {:?}", other),
        }
    }

    /// `MemoryType` variants are equal to themselves and distinct from each other.
    #[test]
    fn test_memory_type_equality() {
        assert_eq!(MemoryType::Local, MemoryType::Local);
        assert_ne!(MemoryType::Local, MemoryType::Host);
        assert_eq!(MemoryType::Shared, MemoryType::Shared);
    }
}