use crate::{DType, Result, Shape, TorshError};
use std::collections::HashMap;

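/// Conversion from an external representation into a Torsh value.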
pub trait FromExternal<T> {
    fn from_external(value: T) -> Result<Self>
    where
        Self: Sized;
}

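/// Conversion from a Torsh value into an external representation.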
pub trait ToExternal<T> {
    fn to_external(&self) -> Result<T>;
}

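/// Zero-copy conversion from an external representation, when the source layout allows it.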
pub trait FromExternalZeroCopy<T> {
    fn from_external_zero_copy(value: T) -> Result<Self>
    where
        Self: Sized;

    /// Returns `true` if `value` can be converted without copying its buffer.
    fn can_zero_copy(value: &T) -> bool;
}

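/// Zero-copy conversion into an external representation, when the tensor layout allows it.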
pub trait ToExternalZeroCopy<T> {
    fn to_external_zero_copy(&self) -> Result<T>;

    /// Returns `true` if the value can be exported without copying its buffer.
    fn can_zero_copy(&self) -> bool;
}

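/// Layout metadata for a NumPy-style array: shape, byte strides, element type,
/// contiguity flags, and total buffer size.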
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NumpyArrayInfo {
    pub shape: Vec<usize>,
    /// Strides in bytes per dimension (NumPy convention).
    pub strides: Vec<isize>,
    pub dtype: DType,
    /// True if the layout is row-major (C order) and dense.
    pub c_contiguous: bool,
    /// True if the layout is column-major (Fortran order) and dense.
    pub f_contiguous: bool,
    /// Total buffer size in bytes.
    pub nbytes: usize,
}

impl NumpyArrayInfo {
    /// Builds info for a freshly allocated, C-contiguous array.
    pub fn new(shape: Vec<usize>, dtype: DType) -> Self {
        let strides = Self::compute_c_strides(&shape, dtype.size());
        let nbytes = shape.iter().product::<usize>() * dtype.size();

        Self {
            c_contiguous: true,
            // 0-d and 1-d arrays are both C- and Fortran-contiguous.
            f_contiguous: shape.len() <= 1,
            shape,
            strides,
            dtype,
            nbytes,
        }
    }

    /// Builds info for an array with caller-supplied strides, detecting contiguity.
    pub fn with_strides(shape: Vec<usize>, strides: Vec<isize>, dtype: DType) -> Self {
        let nbytes = shape.iter().product::<usize>() * dtype.size();
        let c_strides = Self::compute_c_strides(&shape, dtype.size());
        let f_strides = Self::compute_f_strides(&shape, dtype.size());

        Self {
            shape,
            strides: strides.clone(),
            dtype,
            c_contiguous: strides == c_strides,
            f_contiguous: strides == f_strides,
            nbytes,
        }
    }

    /// Byte strides for a dense row-major (C order) layout.
    fn compute_c_strides(shape: &[usize], itemsize: usize) -> Vec<isize> {
        let mut strides = vec![0; shape.len()];
        if !shape.is_empty() {
            let mut stride = itemsize as isize;
            for i in (0..shape.len()).rev() {
                strides[i] = stride;
                stride *= shape[i] as isize;
            }
        }
        strides
    }

    /// Byte strides for a dense column-major (Fortran order) layout.
    fn compute_f_strides(shape: &[usize], itemsize: usize) -> Vec<isize> {
        let mut strides = vec![0; shape.len()];
        if !shape.is_empty() {
            let mut stride = itemsize as isize;
            for i in 0..shape.len() {
                strides[i] = stride;
                stride *= shape[i] as isize;
            }
        }
        strides
    }
}

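/// Description of an ONNX tensor: element type, (possibly dynamic) shape, and optional name.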
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OnnxTensorInfo {
    pub elem_type: OnnxDataType,
    /// Dimensions; `None` marks a dynamic (symbolic) dimension.
    pub shape: Vec<Option<usize>>,
    pub name: Option<String>,
}

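/// ONNX element types; discriminants follow `TensorProto.DataType`.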
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OnnxDataType {
    Undefined = 0,
    Float = 1,
    Uint8 = 2,
    Int8 = 3,
    Uint16 = 4,
    Int16 = 5,
    Int32 = 6,
    Int64 = 7,
    String = 8,
    Bool = 9,
    Float16 = 10,
    Double = 11,
    Uint32 = 12,
    Uint64 = 13,
    Complex64 = 14,
    Complex128 = 15,
    Bfloat16 = 16,
}

impl From<DType> for OnnxDataType {
    fn from(dtype: DType) -> Self {
        match dtype {
            DType::F32 => OnnxDataType::Float,
            DType::F64 => OnnxDataType::Double,
            DType::F16 => OnnxDataType::Float16,
            DType::BF16 => OnnxDataType::Bfloat16,
            DType::I8 => OnnxDataType::Int8,
            DType::U8 => OnnxDataType::Uint8,
            DType::I16 => OnnxDataType::Int16,
            DType::I32 => OnnxDataType::Int32,
            DType::I64 => OnnxDataType::Int64,
            DType::U32 => OnnxDataType::Uint32,
            DType::U64 => OnnxDataType::Uint64,
            DType::Bool => OnnxDataType::Bool,
            DType::C64 => OnnxDataType::Complex64,
            DType::C128 => OnnxDataType::Complex128,
            // Quantized types map to their underlying integer storage types.
            DType::QInt8 => OnnxDataType::Int8,
            DType::QUInt8 => OnnxDataType::Uint8,
            DType::QInt32 => OnnxDataType::Int32,
        }
    }
}

impl TryFrom<OnnxDataType> for DType {
    type Error = TorshError;

    fn try_from(onnx_type: OnnxDataType) -> Result<Self> {
        match onnx_type {
            OnnxDataType::Float => Ok(DType::F32),
            OnnxDataType::Double => Ok(DType::F64),
            OnnxDataType::Float16 => Ok(DType::F16),
            OnnxDataType::Bfloat16 => Ok(DType::BF16),
            OnnxDataType::Int8 => Ok(DType::I8),
            OnnxDataType::Uint8 => Ok(DType::U8),
            OnnxDataType::Int16 => Ok(DType::I16),
            OnnxDataType::Int32 => Ok(DType::I32),
            OnnxDataType::Int64 => Ok(DType::I64),
            OnnxDataType::Bool => Ok(DType::Bool),
            OnnxDataType::Complex64 => Ok(DType::C64),
            OnnxDataType::Complex128 => Ok(DType::C128),
            _ => Err(TorshError::UnsupportedOperation {
                op: "ONNX data type conversion".to_string(),
                dtype: format!("{onnx_type:?}"),
            }),
        }
    }
}

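/// Arrow type information: the logical data type plus optional field metadata.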
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArrowTypeInfo {
    pub data_type: ArrowDataType,
    pub metadata: HashMap<String, String>,
}

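/// Subset of Arrow logical types used for tensor interop.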
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArrowDataType {
    Boolean,
    Int8,
    Int16,
    Int32,
    Int64,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    Float16,
    Float32,
    Float64,
    FixedSizeList(Box<ArrowDataType>, usize),
}

impl From<DType> for ArrowDataType {
    fn from(dtype: DType) -> Self {
        match dtype {
            DType::Bool => ArrowDataType::Boolean,
            DType::I8 | DType::QInt8 => ArrowDataType::Int8,
            DType::U8 | DType::QUInt8 => ArrowDataType::UInt8,
            DType::I16 => ArrowDataType::Int16,
            DType::I32 | DType::QInt32 => ArrowDataType::Int32,
            DType::I64 => ArrowDataType::Int64,
            DType::U32 => ArrowDataType::UInt32,
            DType::U64 => ArrowDataType::UInt64,
            DType::F16 => ArrowDataType::Float16,
            DType::F32 => ArrowDataType::Float32,
            DType::F64 => ArrowDataType::Float64,
            // No bfloat16 variant here, so widen to f32.
            DType::BF16 => ArrowDataType::Float32,
            // Complex values are stored as fixed-size lists of [real, imaginary].
            DType::C64 => ArrowDataType::FixedSizeList(Box::new(ArrowDataType::Float32), 2),
            DType::C128 => ArrowDataType::FixedSizeList(Box::new(ArrowDataType::Float64), 2),
        }
    }
}

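/// Helpers for shape and memory-layout conversions between Torsh and external formats.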
pub struct ConversionUtils;

impl ConversionUtils {
    /// Converts a Torsh shape into a NumPy-style dimension vector.
    pub fn torsh_shape_to_numpy(shape: &Shape) -> Vec<usize> {
        shape.dims().to_vec()
    }

    /// Converts a NumPy-style dimension vector into a Torsh shape.
    pub fn numpy_shape_to_torsh(shape: Vec<usize>) -> Result<Shape> {
        Ok(Shape::new(shape))
    }

    /// Returns `true` if two arrays share the same shape and strides.
    pub fn is_layout_compatible(
        shape1: &[usize],
        strides1: &[isize],
        shape2: &[usize],
        strides2: &[isize],
    ) -> bool {
        if shape1.len() != shape2.len() || shape1 != shape2 {
            return false;
        }

        strides1 == strides2
    }

    /// Scores how efficiently a strided layout uses memory, in `[0.0, 1.0]`.
    /// C-contiguous layouts score 1.0, Fortran-contiguous layouts 0.9, and
    /// anything else the ratio of the dense size to the memory span it touches.
    pub fn layout_efficiency_score(shape: &[usize], strides: &[isize], itemsize: usize) -> f64 {
        if shape.is_empty() {
            return 1.0;
        }

        let c_strides = NumpyArrayInfo::compute_c_strides(shape, itemsize);
        if strides == c_strides {
            return 1.0;
        }

        let f_strides = NumpyArrayInfo::compute_f_strides(shape, itemsize);
        if strides == f_strides {
            return 0.9;
        }

        let total_elements: usize = shape.iter().product();
        let expected_size = total_elements * itemsize;
        let actual_span = Self::compute_memory_span(shape, strides, itemsize);

        if actual_span == 0 {
            return 0.0;
        }

        (expected_size as f64 / actual_span as f64).min(1.0)
    }

337
338 fn compute_memory_span(shape: &[usize], strides: &[isize], itemsize: usize) -> usize {
340 if shape.is_empty() {
341 return 0;
342 }
343
344 let mut min_offset = 0isize;
345 let mut max_offset = 0isize;
346
347 for (&dim, &stride) in shape.iter().zip(strides.iter()) {
348 if dim > 1 {
349 let offset = stride * (dim as isize - 1);
350 min_offset = min_offset.min(offset);
351 max_offset = max_offset.max(offset);
352 }
353 }
354
355 (max_offset - min_offset) as usize + itemsize
356 }
357}
358
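/// Generates human-readable documentation strings about supported conversions.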
pub struct InteropDocs;

impl InteropDocs {
    pub fn supported_conversions() -> String {
        let conversions = vec![
            ("NumPy", "ndarray", "Zero-copy when C-contiguous"),
            ("ndarray", "Array", "Zero-copy when contiguous"),
            ("ONNX", "TensorProto", "Schema mapping"),
            ("Arrow", "Array", "Type mapping with metadata"),
            ("Rust", "Vec<T>", "Direct conversion"),
        ];

        let mut doc = String::from("Supported Tensor Format Conversions:\n");
        doc.push_str("=========================================\n\n");

        for (from, to, notes) in conversions {
            doc.push_str(&format!("• {from} ↔ {to}: {notes}\n"));
        }

        doc
    }

    pub fn conversion_examples() -> String {
        r#"
Conversion Examples:
====================

// NumPy-style array info
let numpy_info = NumpyArrayInfo::new(vec![2, 3, 4], DType::F32);
assert!(numpy_info.c_contiguous);

// ONNX type conversion
let onnx_type = OnnxDataType::from(DType::F32);
let back_to_dtype = DType::try_from(onnx_type).unwrap();

// Arrow type conversion
let arrow_type = ArrowDataType::from(DType::C64);
match arrow_type {
    ArrowDataType::FixedSizeList(inner, size) => {
        assert_eq!(size, 2); // Real and imaginary parts
    }
    _ => panic!("Unexpected type"),
}

// Layout efficiency checking
let shape = vec![1000, 1000];
let c_strides = vec![4000, 4]; // C-contiguous for f32
let efficiency = ConversionUtils::layout_efficiency_score(&shape, &c_strides, 4);
assert_eq!(efficiency, 1.0);
"#
        .to_string()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numpy_array_info() {
        let info = NumpyArrayInfo::new(vec![2, 3, 4], DType::F32);
        assert_eq!(info.shape, vec![2, 3, 4]);
        assert_eq!(info.strides, vec![48, 16, 4]); // row-major byte strides for f32
        assert!(info.c_contiguous);
        assert!(!info.f_contiguous);
        assert_eq!(info.nbytes, 96); // 2 * 3 * 4 elements * 4 bytes
    }

    #[test]
    fn test_onnx_dtype_conversion() {
        let dtypes = vec![
            DType::F32,
            DType::F64,
            DType::I8,
            DType::U8,
            DType::I32,
            DType::Bool,
            DType::C64,
        ];

        for dtype in dtypes {
            let onnx_type = OnnxDataType::from(dtype);
            let back_to_dtype = DType::try_from(onnx_type).expect("try_from should succeed");
            assert_eq!(dtype, back_to_dtype);
        }
    }

    #[test]
    fn test_arrow_dtype_conversion() {
        assert_eq!(ArrowDataType::from(DType::F32), ArrowDataType::Float32);
        assert_eq!(ArrowDataType::from(DType::Bool), ArrowDataType::Boolean);

        match ArrowDataType::from(DType::C64) {
            ArrowDataType::FixedSizeList(inner, size) => {
                assert_eq!(*inner, ArrowDataType::Float32);
                assert_eq!(size, 2);
            }
            _ => panic!("Expected FixedSizeList for C64"),
        }
    }

    #[test]
    fn test_layout_efficiency() {
        let shape = vec![10, 10];
        let itemsize = 4;

        let c_strides = vec![40, 4];
        let efficiency = ConversionUtils::layout_efficiency_score(&shape, &c_strides, itemsize);
        assert_eq!(efficiency, 1.0);

        let f_strides = vec![4, 40];
        let efficiency = ConversionUtils::layout_efficiency_score(&shape, &f_strides, itemsize);
        assert_eq!(efficiency, 0.9);
    }

    #[test]
    fn test_conversion_utils() {
        let shape = Shape::new(vec![2, 3, 4]);
        let numpy_shape = ConversionUtils::torsh_shape_to_numpy(&shape);
        assert_eq!(numpy_shape, vec![2, 3, 4]);

        let back_to_shape = ConversionUtils::numpy_shape_to_torsh(numpy_shape)
            .expect("numpy_shape_to_torsh should succeed");
        assert_eq!(shape.dims(), back_to_shape.dims());
    }

    #[test]
    fn test_layout_compatibility() {
        let shape1 = vec![2, 3];
        let strides1 = vec![12, 4];
        let shape2 = vec![2, 3];
        let strides2 = vec![12, 4];

        assert!(ConversionUtils::is_layout_compatible(
            &shape1, &strides1, &shape2, &strides2
        ));

        // Same shape but different strides should not be layout-compatible.
        let strides3 = vec![4, 8];
        assert!(!ConversionUtils::is_layout_compatible(
            &shape1, &strides1, &shape2, &strides3
        ));
    }
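
    // Additional sketch (not part of the original test suite): checks that
    // `with_strides` recognizes a Fortran-contiguous layout. Assumes column-major
    // byte strides [4, 8] for a 2x3 f32 array.
    #[test]
    fn test_with_strides_fortran_layout() {
        let info = NumpyArrayInfo::with_strides(vec![2, 3], vec![4, 8], DType::F32);
        assert!(!info.c_contiguous);
        assert!(info.f_contiguous);
        assert_eq!(info.nbytes, 24); // 2 * 3 elements * 4 bytes each
    }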
}