
torsh_core/interop.rs

//! Interoperability traits and utilities for ToRSh
//!
//! This module provides conversion traits and utilities for interoperating
//! with other tensor libraries and data formats, including NumPy arrays,
//! ndarray, ONNX, Apache Arrow, and standard Rust types.

use crate::{DType, Result, Shape, TorshError};
use std::collections::HashMap;

/// Trait for converting from external tensor formats to ToRSh types
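///
/// # Examples
///
/// A minimal sketch of an implementation, assuming a hypothetical
/// `RawBuffer` wrapper around `Vec<f32>` (illustrative only, not part of
/// this crate):
///
/// ```ignore
/// struct RawBuffer {
///     data: Vec<f32>,
/// }
///
/// impl FromExternal<Vec<f32>> for RawBuffer {
///     fn from_external(value: Vec<f32>) -> Result<Self> {
///         // The Vec is moved in, so no element copy is needed.
///         Ok(RawBuffer { data: value })
///     }
/// }
/// ```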
pub trait FromExternal<T> {
    /// Convert from an external type to a ToRSh type
    fn from_external(value: T) -> Result<Self>
    where
        Self: Sized;
}

/// Trait for converting ToRSh types to external tensor formats
pub trait ToExternal<T> {
    /// Convert a ToRSh type to an external type
    fn to_external(&self) -> Result<T>;
}

/// Trait for zero-copy conversion from external types when possible
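///
/// # Examples
///
/// A sketch of the intended calling pattern; `Tensor` and `view` stand in
/// for a real implementor and source value (illustrative only):
///
/// ```ignore
/// // `from_external_zero_copy` falls back to copying on its own;
/// // `can_zero_copy` lets callers check first, e.g. to decide whether
/// // the source buffer may be reused afterwards.
/// let borrowed = Tensor::can_zero_copy(&view);
/// let tensor = Tensor::from_external_zero_copy(view)?;
/// ```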
pub trait FromExternalZeroCopy<T> {
    /// Attempt zero-copy conversion, falling back to copy if necessary
    fn from_external_zero_copy(value: T) -> Result<Self>
    where
        Self: Sized;

    /// Check if zero-copy conversion is possible for the given value
    fn can_zero_copy(value: &T) -> bool;
}

/// Trait for zero-copy conversion to external types when possible
pub trait ToExternalZeroCopy<T> {
    /// Attempt zero-copy conversion, falling back to copy if necessary
    fn to_external_zero_copy(&self) -> Result<T>;

    /// Check if zero-copy conversion is possible
    fn can_zero_copy(&self) -> bool;
}

/// NumPy-compatible array metadata
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NumpyArrayInfo {
    /// Array shape
    pub shape: Vec<usize>,
    /// Array strides in bytes
    pub strides: Vec<isize>,
    /// Data type
    pub dtype: DType,
    /// Whether the array is C-contiguous
    pub c_contiguous: bool,
    /// Whether the array is Fortran-contiguous
    pub f_contiguous: bool,
    /// Total size in bytes
    pub nbytes: usize,
}

impl NumpyArrayInfo {
    /// Create new NumPy array info
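    ///
    /// Strides are C-contiguous, computed from the innermost dimension
    /// outwards: for shape `[2, 3, 4]` and 4-byte `F32` items the strides
    /// are `[16 * 3, 4 * 4, 4] = [48, 16, 4]`, and `nbytes` is
    /// `2 * 3 * 4 * 4 = 96`.
    ///
    /// ```ignore
    /// let info = NumpyArrayInfo::new(vec![2, 3, 4], DType::F32);
    /// assert_eq!(info.strides, vec![48, 16, 4]);
    /// assert_eq!(info.nbytes, 96);
    /// assert!(info.c_contiguous);
    /// ```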
    pub fn new(shape: Vec<usize>, dtype: DType) -> Self {
        let strides = Self::compute_c_strides(&shape, dtype.size());
        let nbytes = shape.iter().product::<usize>() * dtype.size();

        Self {
            // Freshly computed C strides are C-contiguous by construction;
            // 0-d and 1-d arrays are Fortran-contiguous as well.
            c_contiguous: true,
            f_contiguous: shape.len() <= 1,
            shape,
            strides,
            dtype,
            nbytes,
        }
    }

    /// Create NumPy array info with custom strides
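    ///
    /// The contiguity flags are set by comparing the given strides against
    /// freshly computed C and Fortran strides. For example, column-major
    /// strides for a `[2, 3]` `F32` array:
    ///
    /// ```ignore
    /// let info = NumpyArrayInfo::with_strides(vec![2, 3], vec![4, 8], DType::F32);
    /// assert!(info.f_contiguous);
    /// assert!(!info.c_contiguous); // C order would be [12, 4]
    /// ```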
    pub fn with_strides(shape: Vec<usize>, strides: Vec<isize>, dtype: DType) -> Self {
        let nbytes = shape.iter().product::<usize>() * dtype.size();
        let c_strides = Self::compute_c_strides(&shape, dtype.size());
        let f_strides = Self::compute_f_strides(&shape, dtype.size());
        let c_contiguous = strides == c_strides;
        let f_contiguous = strides == f_strides;

        Self {
            shape,
            strides,
            dtype,
            c_contiguous,
            f_contiguous,
            nbytes,
        }
    }

    /// Compute C-contiguous strides
    fn compute_c_strides(shape: &[usize], itemsize: usize) -> Vec<isize> {
        let mut strides = vec![0; shape.len()];
        if !shape.is_empty() {
            let mut stride = itemsize as isize;
            for i in (0..shape.len()).rev() {
                strides[i] = stride;
                stride *= shape[i] as isize;
            }
        }
        strides
    }

    /// Compute Fortran-contiguous strides
    fn compute_f_strides(shape: &[usize], itemsize: usize) -> Vec<isize> {
        let mut strides = vec![0; shape.len()];
        if !shape.is_empty() {
            let mut stride = itemsize as isize;
            for i in 0..shape.len() {
                strides[i] = stride;
                stride *= shape[i] as isize;
            }
        }
        strides
    }
}

/// ONNX tensor type information
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OnnxTensorInfo {
    /// Element type
    pub elem_type: OnnxDataType,
    /// Shape (None for unknown dimensions)
    pub shape: Vec<Option<usize>>,
    /// Optional name
    pub name: Option<String>,
}

/// ONNX data types
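///
/// Discriminant values mirror ONNX's `TensorProto.DataType` numbering.
/// Conversions to and from `DType` are provided below; the `DType`
/// direction is fallible because some ONNX types (e.g. `String`) have no
/// ToRSh equivalent:
///
/// ```ignore
/// assert_eq!(OnnxDataType::from(DType::F32), OnnxDataType::Float);
/// assert_eq!(DType::try_from(OnnxDataType::Double).unwrap(), DType::F64);
/// assert!(DType::try_from(OnnxDataType::String).is_err());
/// ```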
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OnnxDataType {
    /// Undefined
    Undefined = 0,
    /// 32-bit floating point
    Float = 1,
    /// 8-bit unsigned integer
    Uint8 = 2,
    /// 8-bit signed integer
    Int8 = 3,
    /// 16-bit unsigned integer
    Uint16 = 4,
    /// 16-bit signed integer
    Int16 = 5,
    /// 32-bit signed integer
    Int32 = 6,
    /// 64-bit signed integer
    Int64 = 7,
    /// String
    String = 8,
    /// Boolean
    Bool = 9,
    /// 16-bit floating point
    Float16 = 10,
    /// 64-bit floating point
    Double = 11,
    /// 32-bit unsigned integer
    Uint32 = 12,
    /// 64-bit unsigned integer
    Uint64 = 13,
    /// Complex 64-bit
    Complex64 = 14,
    /// Complex 128-bit
    Complex128 = 15,
    /// Brain floating point 16-bit
    Bfloat16 = 16,
}

impl From<DType> for OnnxDataType {
    fn from(dtype: DType) -> Self {
        match dtype {
            DType::F32 => OnnxDataType::Float,
            DType::F64 => OnnxDataType::Double,
            DType::F16 => OnnxDataType::Float16,
            DType::BF16 => OnnxDataType::Bfloat16,
            DType::I8 => OnnxDataType::Int8,
            DType::U8 => OnnxDataType::Uint8,
            DType::I16 => OnnxDataType::Int16,
            DType::I32 => OnnxDataType::Int32,
            DType::I64 => OnnxDataType::Int64,
            DType::U32 => OnnxDataType::Uint32,
            DType::U64 => OnnxDataType::Uint64,
            DType::Bool => OnnxDataType::Bool,
            DType::C64 => OnnxDataType::Complex64,
            DType::C128 => OnnxDataType::Complex128,
            // Quantized types map to their base storage types
            DType::QInt8 => OnnxDataType::Int8,
            DType::QUInt8 => OnnxDataType::Uint8,
            DType::QInt32 => OnnxDataType::Int32,
        }
    }
}

impl TryFrom<OnnxDataType> for DType {
    type Error = TorshError;

    fn try_from(onnx_type: OnnxDataType) -> Result<Self> {
        match onnx_type {
            OnnxDataType::Float => Ok(DType::F32),
            OnnxDataType::Double => Ok(DType::F64),
            OnnxDataType::Float16 => Ok(DType::F16),
            OnnxDataType::Bfloat16 => Ok(DType::BF16),
            OnnxDataType::Int8 => Ok(DType::I8),
            OnnxDataType::Uint8 => Ok(DType::U8),
            OnnxDataType::Int16 => Ok(DType::I16),
            OnnxDataType::Int32 => Ok(DType::I32),
            OnnxDataType::Int64 => Ok(DType::I64),
            OnnxDataType::Bool => Ok(DType::Bool),
            OnnxDataType::Complex64 => Ok(DType::C64),
            OnnxDataType::Complex128 => Ok(DType::C128),
            _ => Err(TorshError::UnsupportedOperation {
                op: "ONNX data type conversion".to_string(),
                dtype: format!("{onnx_type:?}"),
            }),
        }
    }
}

/// Apache Arrow type information
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ArrowTypeInfo {
    /// Arrow data type
    pub data_type: ArrowDataType,
    /// Optional metadata
    pub metadata: HashMap<String, String>,
}

/// Simplified Arrow data types
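///
/// Most `DType`s map one-to-one; complex numbers become a fixed-size list
/// of two floats (real and imaginary parts), since Arrow has no native
/// complex type:
///
/// ```ignore
/// assert_eq!(ArrowDataType::from(DType::I64), ArrowDataType::Int64);
/// assert_eq!(
///     ArrowDataType::from(DType::C64),
///     ArrowDataType::FixedSizeList(Box::new(ArrowDataType::Float32), 2)
/// );
/// ```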
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ArrowDataType {
    /// Boolean
    Boolean,
    /// 8-bit signed integer
    Int8,
    /// 16-bit signed integer
    Int16,
    /// 32-bit signed integer
    Int32,
    /// 64-bit signed integer
    Int64,
    /// 8-bit unsigned integer
    UInt8,
    /// 16-bit unsigned integer
    UInt16,
    /// 32-bit unsigned integer
    UInt32,
    /// 64-bit unsigned integer
    UInt64,
    /// 16-bit floating point
    Float16,
    /// 32-bit floating point
    Float32,
    /// 64-bit floating point
    Float64,
    /// Fixed-size list
    FixedSizeList(Box<ArrowDataType>, usize),
}

impl From<DType> for ArrowDataType {
    fn from(dtype: DType) -> Self {
        match dtype {
            DType::Bool => ArrowDataType::Boolean,
            DType::I8 | DType::QInt8 => ArrowDataType::Int8,
            DType::U8 | DType::QUInt8 => ArrowDataType::UInt8,
            DType::I16 => ArrowDataType::Int16,
            DType::I32 | DType::QInt32 => ArrowDataType::Int32,
            DType::I64 => ArrowDataType::Int64,
            DType::U32 => ArrowDataType::UInt32,
            DType::U64 => ArrowDataType::UInt64,
            DType::F16 => ArrowDataType::Float16,
            DType::F32 => ArrowDataType::Float32,
            DType::F64 => ArrowDataType::Float64,
            DType::BF16 => ArrowDataType::Float32, // Best approximation
            DType::C64 => ArrowDataType::FixedSizeList(Box::new(ArrowDataType::Float32), 2),
            DType::C128 => ArrowDataType::FixedSizeList(Box::new(ArrowDataType::Float64), 2),
        }
    }
}

/// Conversion utilities
pub struct ConversionUtils;

impl ConversionUtils {
    /// Convert ToRSh shape to NumPy shape
    pub fn torsh_shape_to_numpy(shape: &Shape) -> Vec<usize> {
        shape.dims().to_vec()
    }

    /// Convert NumPy shape to ToRSh shape
    pub fn numpy_shape_to_torsh(shape: Vec<usize>) -> Result<Shape> {
        Ok(Shape::new(shape))
    }

    /// Check if two arrays are memory layout compatible
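    ///
    /// Layouts are compatible only when both shapes and strides match
    /// exactly:
    ///
    /// ```ignore
    /// assert!(ConversionUtils::is_layout_compatible(
    ///     &[2, 3], &[12, 4],
    ///     &[2, 3], &[12, 4]
    /// ));
    /// // Same shape, different element order: not compatible.
    /// assert!(!ConversionUtils::is_layout_compatible(
    ///     &[2, 3], &[12, 4],
    ///     &[2, 3], &[4, 8]
    /// ));
    /// ```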
    pub fn is_layout_compatible(
        shape1: &[usize],
        strides1: &[isize],
        shape2: &[usize],
        strides2: &[isize],
    ) -> bool {
        // Slice equality already covers the length check.
        shape1 == shape2 && strides1 == strides2
    }

    /// Compute memory layout efficiency score (0.0 to 1.0)
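    ///
    /// C-contiguous layouts score 1.0 and Fortran-contiguous layouts 0.9;
    /// anything else is scored by how densely the elements fill their
    /// memory span. For example, a 10-element `F32` view keeping every
    /// other element (stride 8) spans `8 * 9 + 4 = 76` bytes for 40 bytes
    /// of data, scoring `40 / 76 ≈ 0.53`:
    ///
    /// ```ignore
    /// let score = ConversionUtils::layout_efficiency_score(&[10], &[8], 4);
    /// assert!((score - 40.0 / 76.0).abs() < 1e-12);
    /// ```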
    pub fn layout_efficiency_score(shape: &[usize], strides: &[isize], itemsize: usize) -> f64 {
        if shape.is_empty() {
            return 1.0;
        }

        // Check if it's C-contiguous
        let c_strides = NumpyArrayInfo::compute_c_strides(shape, itemsize);
        if strides == c_strides {
            return 1.0;
        }

        // Check if it's Fortran-contiguous
        let f_strides = NumpyArrayInfo::compute_f_strides(shape, itemsize);
        if strides == f_strides {
            return 0.9;
        }

        // Compute efficiency based on stride patterns
        let total_elements: usize = shape.iter().product();
        let expected_size = total_elements * itemsize;
        let actual_span = Self::compute_memory_span(shape, strides, itemsize);

        if actual_span == 0 {
            return 0.0;
        }

        (expected_size as f64 / actual_span as f64).min(1.0)
    }

    /// Compute the span of memory used by an array
    fn compute_memory_span(shape: &[usize], strides: &[isize], itemsize: usize) -> usize {
        if shape.is_empty() {
            return 0;
        }

        let mut min_offset = 0isize;
        let mut max_offset = 0isize;

        for (&dim, &stride) in shape.iter().zip(strides.iter()) {
            if dim > 1 {
                let offset = stride * (dim as isize - 1);
                min_offset = min_offset.min(offset);
                max_offset = max_offset.max(offset);
            }
        }

        (max_offset - min_offset) as usize + itemsize
    }
}

/// Documentation utilities for the interop module
pub struct InteropDocs;

impl InteropDocs {
    /// Generate documentation for supported conversions
    pub fn supported_conversions() -> String {
        let conversions = vec![
            ("NumPy", "ndarray", "Zero-copy when C-contiguous"),
            ("ndarray", "Array", "Zero-copy when contiguous"),
            ("ONNX", "TensorProto", "Schema mapping"),
            ("Arrow", "Array", "Type mapping with metadata"),
            ("Rust", "Vec<T>", "Direct conversion"),
        ];

        let mut doc = String::from("Supported Tensor Format Conversions:\n");
        doc.push_str("=========================================\n\n");

        for (from, to, notes) in conversions {
            doc.push_str(&format!("• {from} ↔ {to}: {notes}\n"));
        }

        doc
    }

    /// Generate examples for common conversion patterns
    pub fn conversion_examples() -> String {
        r#"
Conversion Examples:
====================

// NumPy-style array info
let numpy_info = NumpyArrayInfo::new(vec![2, 3, 4], DType::F32);
assert!(numpy_info.c_contiguous);

// ONNX type conversion
let onnx_type = OnnxDataType::from(DType::F32);
let back_to_dtype = DType::try_from(onnx_type).unwrap();

// Arrow type conversion
let arrow_type = ArrowDataType::from(DType::C64);
match arrow_type {
    ArrowDataType::FixedSizeList(_inner, size) => {
        assert_eq!(size, 2); // Real and imaginary parts
    }
    _ => panic!("Unexpected type"),
}

// Layout efficiency checking
let shape = vec![1000, 1000];
let c_strides = vec![4000, 4]; // C-contiguous strides for f32
let efficiency = ConversionUtils::layout_efficiency_score(&shape, &c_strides, 4);
assert_eq!(efficiency, 1.0);
"#
        .to_string()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_numpy_array_info() {
        let info = NumpyArrayInfo::new(vec![2, 3, 4], DType::F32);
        assert_eq!(info.shape, vec![2, 3, 4]);
        assert_eq!(info.strides, vec![48, 16, 4]); // C-contiguous strides for f32
        assert!(info.c_contiguous);
        assert!(!info.f_contiguous);
        assert_eq!(info.nbytes, 96); // 2*3*4*4 bytes
    }

    #[test]
    fn test_onnx_dtype_conversion() {
        // Test round-trip conversion
        let dtypes = vec![
            DType::F32,
            DType::F64,
            DType::I8,
            DType::U8,
            DType::I32,
            DType::Bool,
            DType::C64,
        ];

        for dtype in dtypes {
            let onnx_type = OnnxDataType::from(dtype);
            let back_to_dtype = DType::try_from(onnx_type).expect("try_from should succeed");
            assert_eq!(dtype, back_to_dtype);
        }
    }

    #[test]
    fn test_arrow_dtype_conversion() {
        assert_eq!(ArrowDataType::from(DType::F32), ArrowDataType::Float32);
        assert_eq!(ArrowDataType::from(DType::Bool), ArrowDataType::Boolean);

        // Test complex types
        match ArrowDataType::from(DType::C64) {
            ArrowDataType::FixedSizeList(inner, size) => {
                assert_eq!(*inner, ArrowDataType::Float32);
                assert_eq!(size, 2);
            }
            _ => panic!("Expected FixedSizeList for C64"),
        }
    }

    #[test]
    fn test_layout_efficiency() {
        let shape = vec![10, 10];
        let itemsize = 4;

        // C-contiguous (perfect efficiency)
        let c_strides = vec![40, 4];
        let efficiency = ConversionUtils::layout_efficiency_score(&shape, &c_strides, itemsize);
        assert_eq!(efficiency, 1.0);

        // F-contiguous (very good efficiency)
        let f_strides = vec![4, 40];
        let efficiency = ConversionUtils::layout_efficiency_score(&shape, &f_strides, itemsize);
        assert_eq!(efficiency, 0.9);
    }

    #[test]
    fn test_conversion_utils() {
        let shape = Shape::new(vec![2, 3, 4]);
        let numpy_shape = ConversionUtils::torsh_shape_to_numpy(&shape);
        assert_eq!(numpy_shape, vec![2, 3, 4]);

        let back_to_shape = ConversionUtils::numpy_shape_to_torsh(numpy_shape)
            .expect("numpy_shape_to_torsh should succeed");
        assert_eq!(shape.dims(), back_to_shape.dims());
    }

    #[test]
    fn test_layout_compatibility() {
        let shape1 = vec![2, 3];
        let strides1 = vec![12, 4];
        let shape2 = vec![2, 3];
        let strides2 = vec![12, 4];

        assert!(ConversionUtils::is_layout_compatible(
            &shape1, &strides1, &shape2, &strides2
        ));

        let strides3 = vec![4, 8]; // Different strides
        assert!(!ConversionUtils::is_layout_compatible(
            &shape1, &strides1, &shape2, &strides3
        ));
    }
}