1use std::collections::HashSet;
4
5use crate::ops::{ElemOp, ReduceOp};
6
/// Kind of device a backend can execute on.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    CPU,
    GPU,
    TPU,
    /// Device kind outside the built-in set, tagged with a `u32` whose meaning
    /// is not defined in this module.
    Custom(u32),
}

impl DeviceType {
    /// Returns the display name of this device kind.
    ///
    /// Every arm yields a string literal, so the result is `'static` and may
    /// outlive `self`. All `Custom` variants map to `"Custom"`; the numeric tag
    /// is not included in the name.
    pub fn as_str(&self) -> &'static str {
        match self {
            DeviceType::CPU => "CPU",
            DeviceType::GPU => "GPU",
            DeviceType::TPU => "TPU",
            DeviceType::Custom(_) => "Custom",
        }
    }
}
26
/// Element data type of a tensor.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    F32,
    F64,
    I32,
    I64,
    Bool,
    /// Type outside the built-in set, tagged with a `u32` whose meaning is not
    /// defined in this module.
    Custom(u32),
}

impl DType {
    /// Returns the lowercase short name of this type (e.g. `"f32"`).
    ///
    /// The result is a string literal and therefore `'static`. All `Custom`
    /// variants map to `"custom"`; the numeric tag is not included.
    pub fn as_str(&self) -> &'static str {
        match self {
            DType::F32 => "f32",
            DType::F64 => "f64",
            DType::I32 => "i32",
            DType::I64 => "i64",
            DType::Bool => "bool",
            DType::Custom(_) => "custom",
        }
    }

    /// Returns the size in bytes of one element of this type.
    ///
    /// `Bool` is stored as one byte. `Custom` types return `0` because their
    /// size is not known here. Callers must treat `0` as "unknown", not as a
    /// real element size (e.g. do not multiply it into a buffer length).
    pub fn byte_size(&self) -> usize {
        match self {
            DType::F32 | DType::I32 => 4,
            DType::F64 | DType::I64 => 8,
            DType::Bool => 1,
            DType::Custom(_) => 0,
        }
    }
}
59
/// Optional capability a backend may advertise in
/// `BackendCapabilities::features`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Feature {
    Autodiff,
    BatchExecution,
    SparseTensors,
    MixedPrecision,
    SIMDAcceleration,
    GPUAcceleration,
    DistributedExecution,
    JIT,
    /// Feature outside the built-in set, tagged with a `u32` whose meaning is
    /// not defined in this module.
    Custom(u32),
}

impl Feature {
    /// Returns the feature name, spelled exactly like the variant.
    ///
    /// The result is a string literal and therefore `'static`. All `Custom`
    /// variants map to `"Custom"`; the numeric tag is not included.
    pub fn as_str(&self) -> &'static str {
        match self {
            Feature::Autodiff => "Autodiff",
            Feature::BatchExecution => "BatchExecution",
            Feature::SparseTensors => "SparseTensors",
            Feature::MixedPrecision => "MixedPrecision",
            Feature::SIMDAcceleration => "SIMDAcceleration",
            Feature::GPUAcceleration => "GPUAcceleration",
            Feature::DistributedExecution => "DistributedExecution",
            Feature::JIT => "JIT",
            Feature::Custom(_) => "Custom",
        }
    }
}
98
99#[derive(Debug, Clone)]
101pub struct BackendCapabilities {
102 pub name: String,
103 pub version: String,
104 pub supported_devices: HashSet<DeviceType>,
105 pub supported_dtypes: HashSet<DType>,
106 pub features: HashSet<Feature>,
107 pub max_tensor_dims: usize,
108 pub max_tensor_size: Option<usize>,
109}
110
111impl BackendCapabilities {
112 pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
113 BackendCapabilities {
114 name: name.into(),
115 version: version.into(),
116 supported_devices: HashSet::new(),
117 supported_dtypes: HashSet::new(),
118 features: HashSet::new(),
119 max_tensor_dims: 8, max_tensor_size: None,
121 }
122 }
123
124 pub fn with_device(mut self, device: DeviceType) -> Self {
125 self.supported_devices.insert(device);
126 self
127 }
128
129 pub fn with_dtype(mut self, dtype: DType) -> Self {
130 self.supported_dtypes.insert(dtype);
131 self
132 }
133
134 pub fn with_feature(mut self, feature: Feature) -> Self {
135 self.features.insert(feature);
136 self
137 }
138
139 pub fn with_max_dims(mut self, max_dims: usize) -> Self {
140 self.max_tensor_dims = max_dims;
141 self
142 }
143
144 pub fn supports_device(&self, device: DeviceType) -> bool {
145 self.supported_devices.contains(&device)
146 }
147
148 pub fn supports_dtype(&self, dtype: DType) -> bool {
149 self.supported_dtypes.contains(&dtype)
150 }
151
152 pub fn supports_feature(&self, feature: Feature) -> bool {
153 self.features.contains(&feature)
154 }
155
156 pub fn can_execute_on(&self, device: DeviceType, dtype: DType) -> bool {
157 self.supports_device(device) && self.supports_dtype(dtype)
158 }
159
160 pub fn summary(&self) -> String {
162 let mut summary = String::new();
163 summary.push_str(&format!("Backend: {} v{}\n", self.name, self.version));
164 summary.push_str("Devices: ");
165 for device in &self.supported_devices {
166 summary.push_str(&format!("{} ", device.as_str()));
167 }
168 summary.push('\n');
169 summary.push_str("Data Types: ");
170 for dtype in &self.supported_dtypes {
171 summary.push_str(&format!("{} ", dtype.as_str()));
172 }
173 summary.push('\n');
174 summary.push_str("Features: ");
175 for feature in &self.features {
176 summary.push_str(&format!("{} ", feature.as_str()));
177 }
178 summary.push('\n');
179 summary.push_str(&format!("Max Tensor Dims: {}\n", self.max_tensor_dims));
180 summary
181 }
182}
183
184pub trait TlCapabilities {
186 fn capabilities(&self) -> &BackendCapabilities;
188
189 fn supports_elem_op(&self, op: ElemOp) -> bool {
191 let _ = op;
192 true }
194
195 fn supports_reduce_op(&self, op: ReduceOp) -> bool {
197 let _ = op;
198 true }
200
201 fn supports_einsum(&self, spec: &str) -> bool {
203 let _ = spec;
204 true }
206
207 fn available_devices(&self) -> Vec<DeviceType> {
209 self.capabilities()
210 .supported_devices
211 .iter()
212 .copied()
213 .collect()
214 }
215
216 fn default_device(&self) -> DeviceType {
218 self.available_devices()
219 .first()
220 .copied()
221 .unwrap_or(DeviceType::CPU)
222 }
223}
224
#[cfg(test)]
mod tests {
    use super::*;

    // Name lookups for the two most common device kinds.
    #[test]
    fn test_device_type() {
        for (device, expected) in [(DeviceType::CPU, "CPU"), (DeviceType::GPU, "GPU")] {
            assert_eq!(device.as_str(), expected);
        }
    }

    // Short names and element widths of the float types.
    #[test]
    fn test_dtype() {
        let single = DType::F32;
        assert_eq!(single.as_str(), "f32");
        assert_eq!(single.byte_size(), 4);
        assert_eq!(DType::F64.byte_size(), 8);
    }

    #[test]
    fn test_feature() {
        assert_eq!(Feature::Autodiff.as_str(), "Autodiff");
    }

    // Builder output is reflected by every membership query.
    #[test]
    fn test_backend_capabilities() {
        let caps = BackendCapabilities::new("TestBackend", "1.0")
            .with_device(DeviceType::CPU)
            .with_device(DeviceType::GPU)
            .with_dtype(DType::F32)
            .with_dtype(DType::F64)
            .with_feature(Feature::Autodiff)
            .with_max_dims(10);

        let device_cases = [
            (DeviceType::CPU, true),
            (DeviceType::GPU, true),
            (DeviceType::TPU, false),
        ];
        for (device, expected) in device_cases {
            assert_eq!(caps.supports_device(device), expected, "device {:?}", device);
        }

        for (dtype, expected) in [(DType::F32, true), (DType::I32, false)] {
            assert_eq!(caps.supports_dtype(dtype), expected, "dtype {:?}", dtype);
        }

        let feature_cases = [(Feature::Autodiff, true), (Feature::BatchExecution, false)];
        for (feature, expected) in feature_cases {
            assert_eq!(caps.supports_feature(feature), expected, "feature {:?}", feature);
        }

        assert_eq!(caps.max_tensor_dims, 10);
    }

    // Execution needs both the device and the dtype to be supported.
    #[test]
    fn test_can_execute_on() {
        let caps = BackendCapabilities::new("TestBackend", "1.0")
            .with_device(DeviceType::CPU)
            .with_dtype(DType::F32);

        let cases = [
            (DeviceType::CPU, DType::F32, true),
            (DeviceType::GPU, DType::F32, false),
            (DeviceType::CPU, DType::F64, false),
        ];
        for (device, dtype, expected) in cases {
            assert_eq!(
                caps.can_execute_on(device, dtype),
                expected,
                "{:?} / {:?}",
                device,
                dtype
            );
        }
    }

    // The summary mentions every configured piece of information.
    #[test]
    fn test_capabilities_summary() {
        let text = BackendCapabilities::new("TestBackend", "1.0")
            .with_device(DeviceType::CPU)
            .with_dtype(DType::F32)
            .with_feature(Feature::Autodiff)
            .summary();

        for needle in ["TestBackend", "1.0", "CPU", "f32", "Autodiff"] {
            assert!(text.contains(needle), "summary lacks {:?}: {}", needle, text);
        }
    }
}