// Source file: wolfram_expr/array_buf.rs

//! Shared implementation backing both [`NumericArray`][crate::NumericArray] and
//! [`PackedArray`][crate::PackedArray].
//!
//! Both types are dense N-dimensional buffers with an element-type tag — the only
//! difference is the *set* of valid element types (`PackedArray` supports a strict
//! subset of the element types `NumericArray` supports).
7
8use crate::wxf::NumericArrayEnum;
9use crate::ByteArray;
10
11/// Sealed marker for Rust primitives valid as an array element. The set is
12/// fixed (the C ABI only knows these widths), so external types can't
13/// implement [`ArrayElement`].
14mod sealed {
15    use crate::complex::{Complex32, Complex64};
16    pub trait Sealed {}
17    impl Sealed for i8 {}
18    impl Sealed for i16 {}
19    impl Sealed for i32 {}
20    impl Sealed for i64 {}
21    impl Sealed for u8 {}
22    impl Sealed for u16 {}
23    impl Sealed for u32 {}
24    impl Sealed for u64 {}
25    impl Sealed for f32 {}
26    impl Sealed for f64 {}
27    impl Sealed for Complex32 {}
28    impl Sealed for Complex64 {}
29}
30
31/// Connects a Rust primitive to its element-type discriminant. Implemented
32/// once per `(type, tag)` pair: e.g. `i32: ArrayElement<NumericArrayEnum>`
33/// (with `TAG = Integer32`) and `i32: ArrayElement<PackedArrayEnum>` (with
34/// `TAG = Integer32`). Sealed — only the primitives in [`sealed`] above can
35/// satisfy the `Sealed` super-bound.
36pub trait ArrayElement<Tag: Copy + PartialEq>: Copy + 'static + sealed::Sealed {
37    /// The element-type tag for `Self` under this array kind.
38    const TAG: Tag;
39}
40
41/// Generic dense N-dimensional buffer parameterized by an element-type tag.
42///
43/// `NumericArray = ArrayBuf<NumericArrayEnum>` and
44/// `PackedArray   = ArrayBuf<PackedArrayEnum>`. Each provides specialized
45/// constructors (`from_slice<T: …Element>`) and a typed slice view; shape,
46/// byte access, and element count are shared via this struct.
47#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
48pub struct ArrayBuf<Tag> {
49    pub(crate) data_type: Tag,
50    pub(crate) dimensions: Vec<usize>,
51    pub(crate) bytes: ByteArray,
52}
53
54impl<Tag: Copy + PartialEq> ArrayBuf<Tag> {
55    /// Construct from raw parts. Caller is responsible for ensuring
56    /// `bytes.len() == prod(dimensions) * element_size`.
57    pub fn new(data_type: Tag, dimensions: Vec<usize>, bytes: ByteArray) -> Self {
58        ArrayBuf {
59            data_type,
60            dimensions,
61            bytes,
62        }
63    }
64
65    /// The concrete element-type tag.
66    pub fn data_type(&self) -> Tag {
67        self.data_type
68    }
69
70    /// Multi-dimensional shape.
71    pub fn dimensions(&self) -> &[usize] {
72        &self.dimensions
73    }
74
75    /// Raw byte buffer.
76    pub fn as_bytes(&self) -> &[u8] {
77        &self.bytes
78    }
79
80    /// Number of dimensions.
81    pub fn rank(&self) -> usize {
82        self.dimensions.len()
83    }
84
85    /// Total element count = product of dimensions.
86    pub fn element_count(&self) -> usize {
87        self.dimensions.iter().product()
88    }
89
90    /// Total byte length of the buffer.
91    pub fn byte_count(&self) -> usize {
92        self.bytes.len()
93    }
94
95    /// Construct from a typed slice. Dimensions must satisfy
96    /// `prod(dimensions) == slice.len()`.
97    pub fn from_slice<T: ArrayElement<Tag>>(dimensions: Vec<usize>, slice: &[T]) -> Self {
98        assert_eq!(
99            dimensions.iter().product::<usize>(),
100            slice.len(),
101            "ArrayBuf::from_slice: dims product must equal slice length"
102        );
103        let bytes: &[u8] = unsafe {
104            std::slice::from_raw_parts(
105                slice.as_ptr() as *const u8,
106                std::mem::size_of_val(slice),
107            )
108        };
109        ArrayBuf::new(T::TAG, dimensions, ByteArray::from(bytes))
110    }
111
112    /// Try to view the buffer as a slice of `T`. Returns `None` if `T`'s tag
113    /// doesn't match this array's [`data_type`][Self::data_type].
114    pub fn try_as_slice<T: ArrayElement<Tag>>(&self) -> Option<&[T]> {
115        if self.data_type != T::TAG {
116            return None;
117        }
118        let bytes = self.as_bytes();
119        let elem_size = std::mem::size_of::<T>();
120        debug_assert_eq!(bytes.len() % elem_size, 0);
121        if bytes.is_empty() {
122            return Some(&[]);
123        }
124        // SAFETY: tag matches T, so the bytes were produced from a `[T]`.
125        Some(unsafe {
126            std::slice::from_raw_parts(
127                bytes.as_ptr() as *const T,
128                bytes.len() / elem_size,
129            )
130        })
131    }
132}
133
134/// Common read API implemented by both the owned [`crate::NumericArray`] /
135/// [`crate::PackedArray`] and the runtime-handle `NumericArray<T>` in
136/// `wolfram-library-link`.
137pub trait NumericArrayRead {
138    /// The element-type tag.
139    fn data_type(&self) -> NumericArrayEnum;
140    /// The multi-dimensional shape (row-major).
141    fn dimensions(&self) -> &[usize];
142    /// The flat little-endian byte buffer.
143    fn as_bytes(&self) -> &[u8];
144
145    /// Number of dimensions.
146    fn rank(&self) -> usize {
147        self.dimensions().len()
148    }
149    /// Total element count = product of dimensions.
150    fn element_count(&self) -> usize {
151        self.dimensions().iter().product()
152    }
153    /// Total byte length of the buffer.
154    fn byte_count(&self) -> usize {
155        self.as_bytes().len()
156    }
157    /// Bytes per element.
158    fn element_size(&self) -> usize {
159        self.data_type().size_in_bytes()
160    }
161
162    /// View the buffer as `&[T]` if `T`'s element type matches; else `None`.
163    fn try_as_slice<T: ArrayElement<NumericArrayEnum>>(&self) -> Option<&[T]> {
164        if self.data_type() != T::TAG {
165            return None;
166        }
167        let bytes = self.as_bytes();
168        let elem_size = std::mem::size_of::<T>();
169        debug_assert_eq!(bytes.len() % elem_size, 0);
170        // SAFETY: tag matches, alignment guaranteed by construction.
171        Some(unsafe {
172            std::slice::from_raw_parts(
173                bytes.as_ptr() as *const T,
174                bytes.len() / elem_size,
175            )
176        })
177    }
178}
179
180impl<Tag: Into<NumericArrayEnum> + Copy + PartialEq> NumericArrayRead for ArrayBuf<Tag> {
181    fn data_type(&self) -> NumericArrayEnum {
182        self.data_type.into()
183    }
184    fn dimensions(&self) -> &[usize] {
185        &self.dimensions
186    }
187    fn as_bytes(&self) -> &[u8] {
188        &self.bytes
189    }
190}