Skip to main content

tensorlogic_infer/
capabilities.rs

1//! Backend capability queries and feature detection.
2
3use std::collections::HashSet;
4
5use crate::ops::{ElemOp, ReduceOp};
6
/// Device types that a backend can execute on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// Host CPU execution.
    CPU,
    /// GPU accelerator.
    GPU,
    /// TPU accelerator.
    TPU,
    /// Backend-defined device, identified by an opaque numeric id.
    Custom(u32),
}

impl DeviceType {
    /// Human-readable name for this device type.
    ///
    /// All `Custom` devices map to `"Custom"` regardless of their id.
    /// Returns `&'static str` (the strings are literals), which is
    /// backward-compatible with the previous `&str` signature and lets
    /// callers keep the name without borrowing `self`.
    pub fn as_str(&self) -> &'static str {
        match self {
            DeviceType::CPU => "CPU",
            DeviceType::GPU => "GPU",
            DeviceType::TPU => "TPU",
            DeviceType::Custom(_) => "Custom",
        }
    }
}
26
/// Precision/data type support.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 32-bit IEEE 754 float.
    F32,
    /// 64-bit IEEE 754 float.
    F64,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
    /// Boolean (stored as one byte).
    Bool,
    /// Backend-defined type, identified by an opaque numeric id.
    Custom(u32),
}

impl DType {
    /// Lowercase type name, e.g. `"f32"`.
    ///
    /// All `Custom` types map to `"custom"` regardless of their id.
    /// Returns `&'static str` (the strings are literals), which is
    /// backward-compatible with the previous `&str` signature.
    pub fn as_str(&self) -> &'static str {
        match self {
            DType::F32 => "f32",
            DType::F64 => "f64",
            DType::I32 => "i32",
            DType::I64 => "i64",
            DType::Bool => "bool",
            DType::Custom(_) => "custom",
        }
    }

    /// Size of one element of this type in bytes.
    ///
    /// `Custom` types return `0` as an "unknown size" sentinel — callers
    /// must not use it for buffer sizing of custom types.
    pub fn byte_size(&self) -> usize {
        match self {
            DType::F32 | DType::I32 => 4,
            DType::F64 | DType::I64 => 8,
            DType::Bool => 1,
            DType::Custom(_) => 0,
        }
    }
}
59
/// Backend feature flags.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Feature {
    /// Supports automatic differentiation
    Autodiff,
    /// Supports batched execution
    BatchExecution,
    /// Supports sparse tensors
    SparseTensors,
    /// Supports mixed precision
    MixedPrecision,
    /// Supports SIMD acceleration
    SIMDAcceleration,
    /// Supports GPU execution
    GPUAcceleration,
    /// Supports distributed execution
    DistributedExecution,
    /// Supports JIT compilation
    JIT,
    /// Custom feature, identified by an opaque numeric id
    Custom(u32),
}

impl Feature {
    /// Name of the feature flag, matching the variant name.
    ///
    /// All `Custom` features map to `"Custom"` regardless of their id.
    /// Returns `&'static str` (the strings are literals), which is
    /// backward-compatible with the previous `&str` signature.
    pub fn as_str(&self) -> &'static str {
        match self {
            Feature::Autodiff => "Autodiff",
            Feature::BatchExecution => "BatchExecution",
            Feature::SparseTensors => "SparseTensors",
            Feature::MixedPrecision => "MixedPrecision",
            Feature::SIMDAcceleration => "SIMDAcceleration",
            Feature::GPUAcceleration => "GPUAcceleration",
            Feature::DistributedExecution => "DistributedExecution",
            Feature::JIT => "JIT",
            Feature::Custom(_) => "Custom",
        }
    }
}
98
/// Backend capabilities descriptor
///
/// Describes what a backend can do: which devices and data types it
/// supports, which optional features it advertises, and its tensor
/// size limits. Built fluently via the `with_*` methods on the impl.
#[derive(Debug, Clone)]
pub struct BackendCapabilities {
    /// Backend display name.
    pub name: String,
    /// Backend version string (free-form; not parsed here).
    pub version: String,
    /// Devices this backend can execute on.
    pub supported_devices: HashSet<DeviceType>,
    /// Data types this backend can operate on.
    pub supported_dtypes: HashSet<DType>,
    /// Optional feature flags advertised by the backend.
    pub features: HashSet<Feature>,
    /// Maximum supported tensor rank (number of dimensions).
    pub max_tensor_dims: usize,
    /// Maximum tensor size if the backend imposes one; `None` = unbounded.
    /// NOTE(review): units (elements vs bytes) are not established here — confirm with backends.
    pub max_tensor_size: Option<usize>,
}
110
111impl BackendCapabilities {
112    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
113        BackendCapabilities {
114            name: name.into(),
115            version: version.into(),
116            supported_devices: HashSet::new(),
117            supported_dtypes: HashSet::new(),
118            features: HashSet::new(),
119            max_tensor_dims: 8, // Default max rank
120            max_tensor_size: None,
121        }
122    }
123
124    pub fn with_device(mut self, device: DeviceType) -> Self {
125        self.supported_devices.insert(device);
126        self
127    }
128
129    pub fn with_dtype(mut self, dtype: DType) -> Self {
130        self.supported_dtypes.insert(dtype);
131        self
132    }
133
134    pub fn with_feature(mut self, feature: Feature) -> Self {
135        self.features.insert(feature);
136        self
137    }
138
139    pub fn with_max_dims(mut self, max_dims: usize) -> Self {
140        self.max_tensor_dims = max_dims;
141        self
142    }
143
144    pub fn supports_device(&self, device: DeviceType) -> bool {
145        self.supported_devices.contains(&device)
146    }
147
148    pub fn supports_dtype(&self, dtype: DType) -> bool {
149        self.supported_dtypes.contains(&dtype)
150    }
151
152    pub fn supports_feature(&self, feature: Feature) -> bool {
153        self.features.contains(&feature)
154    }
155
156    pub fn can_execute_on(&self, device: DeviceType, dtype: DType) -> bool {
157        self.supports_device(device) && self.supports_dtype(dtype)
158    }
159
160    /// Generate a summary of capabilities
161    pub fn summary(&self) -> String {
162        let mut summary = String::new();
163        summary.push_str(&format!("Backend: {} v{}\n", self.name, self.version));
164        summary.push_str("Devices: ");
165        for device in &self.supported_devices {
166            summary.push_str(&format!("{} ", device.as_str()));
167        }
168        summary.push('\n');
169        summary.push_str("Data Types: ");
170        for dtype in &self.supported_dtypes {
171            summary.push_str(&format!("{} ", dtype.as_str()));
172        }
173        summary.push('\n');
174        summary.push_str("Features: ");
175        for feature in &self.features {
176            summary.push_str(&format!("{} ", feature.as_str()));
177        }
178        summary.push('\n');
179        summary.push_str(&format!("Max Tensor Dims: {}\n", self.max_tensor_dims));
180        summary
181    }
182}
183
184/// Trait for backends to advertise their capabilities
185pub trait TlCapabilities {
186    /// Get backend capabilities
187    fn capabilities(&self) -> &BackendCapabilities;
188
189    /// Check if a specific operation is supported
190    fn supports_elem_op(&self, op: ElemOp) -> bool {
191        let _ = op;
192        true // Default: support all ops
193    }
194
195    /// Check if a specific reduction operation is supported
196    fn supports_reduce_op(&self, op: ReduceOp) -> bool {
197        let _ = op;
198        true // Default: support all ops
199    }
200
201    /// Check if einsum is supported with the given spec
202    fn supports_einsum(&self, spec: &str) -> bool {
203        let _ = spec;
204        true // Default: support all einsum specs
205    }
206
207    /// Get available devices
208    fn available_devices(&self) -> Vec<DeviceType> {
209        self.capabilities()
210            .supported_devices
211            .iter()
212            .copied()
213            .collect()
214    }
215
216    /// Get default device
217    fn default_device(&self) -> DeviceType {
218        self.available_devices()
219            .first()
220            .copied()
221            .unwrap_or(DeviceType::CPU)
222    }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_device_type() {
        assert_eq!(DeviceType::CPU.as_str(), "CPU");
        assert_eq!(DeviceType::GPU.as_str(), "GPU");
    }

    #[test]
    fn test_dtype() {
        // Avoid shadowing the primitive type names `f32`/`f64` with locals.
        let single = DType::F32;
        assert_eq!(single.as_str(), "f32");
        assert_eq!(single.byte_size(), 4);

        let double = DType::F64;
        assert_eq!(double.byte_size(), 8);
    }

    #[test]
    fn test_feature() {
        assert_eq!(Feature::Autodiff.as_str(), "Autodiff");
    }

    #[test]
    fn test_backend_capabilities() {
        let caps = BackendCapabilities::new("TestBackend", "1.0")
            .with_device(DeviceType::CPU)
            .with_device(DeviceType::GPU)
            .with_dtype(DType::F32)
            .with_dtype(DType::F64)
            .with_feature(Feature::Autodiff)
            .with_max_dims(10);

        // Devices: registered ones are present, unregistered ones are not.
        assert!(caps.supports_device(DeviceType::CPU) && caps.supports_device(DeviceType::GPU));
        assert!(!caps.supports_device(DeviceType::TPU));

        // Data types.
        assert!(caps.supports_dtype(DType::F32));
        assert!(!caps.supports_dtype(DType::I32));

        // Features.
        assert!(caps.supports_feature(Feature::Autodiff));
        assert!(!caps.supports_feature(Feature::BatchExecution));

        // Builder override of the default max rank.
        assert_eq!(caps.max_tensor_dims, 10);
    }

    #[test]
    fn test_can_execute_on() {
        let caps = BackendCapabilities::new("TestBackend", "1.0")
            .with_device(DeviceType::CPU)
            .with_dtype(DType::F32);

        // Only the exact (device, dtype) combination registered succeeds.
        assert!(caps.can_execute_on(DeviceType::CPU, DType::F32));
        assert!(!caps.can_execute_on(DeviceType::GPU, DType::F32));
        assert!(!caps.can_execute_on(DeviceType::CPU, DType::F64));
    }

    #[test]
    fn test_capabilities_summary() {
        let caps = BackendCapabilities::new("TestBackend", "1.0")
            .with_device(DeviceType::CPU)
            .with_dtype(DType::F32)
            .with_feature(Feature::Autodiff);

        let text = caps.summary();
        for expected in &["TestBackend", "1.0", "CPU", "f32", "Autodiff"] {
            assert!(text.contains(expected), "summary missing {expected:?}");
        }
    }
}