wolfram_library_link/
numeric_array.rs

1use std::ffi::c_void;
2use std::fmt;
3use std::marker::PhantomData;
4use std::mem::{self, MaybeUninit};
5
6use static_assertions::assert_not_impl_any;
7
8use crate::{rtl, sys};
9
10#[rustfmt::skip]
11use crate::sys::MNumericArray_Data_Type::{
12    MNumericArray_Type_Bit8 as BIT8_TYPE,
13    MNumericArray_Type_Bit16 as BIT16_TYPE,
14    MNumericArray_Type_Bit32 as BIT32_TYPE,
15    MNumericArray_Type_Bit64 as BIT64_TYPE,
16
17    MNumericArray_Type_UBit8 as UBIT8_TYPE,
18    MNumericArray_Type_UBit16 as UBIT16_TYPE,
19    MNumericArray_Type_UBit32 as UBIT32_TYPE,
20    MNumericArray_Type_UBit64 as UBIT64_TYPE,
21
22    MNumericArray_Type_Real32 as REAL32_TYPE,
23    MNumericArray_Type_Real64 as REAL64_TYPE,
24
25    MNumericArray_Type_Complex_Real32 as COMPLEX_REAL32_TYPE,
26    MNumericArray_Type_Complex_Real64 as COMPLEX_REAL64_TYPE,
27};
28
29use crate::sys::MNumericArray_Convert_Method::*;
30
31/// Native Wolfram [`NumericArray`][ref/NumericArray]<sub>WL</sub>.
32///
33/// This type is an ABI-compatible wrapper around [`wolfram_library_link_sys::MNumericArray`].
34///
35/// A [`NumericArray`] can contain any type `T` which satisfies the trait
36/// [`NumericArrayType`].
37///
38/// Use [`NumericArray::kind()`] to dynamically resolve a `NumericArray` with unknown
39/// element type into a `NumericArray<T>` with explicit element type.
40///
41/// Use [`UninitNumericArray`] to construct a [`NumericArray`] without requiring an
42/// intermediate allocation to copy the elements from.
43///
44/// [ref/NumericArray]: https://reference.wolfram.com/language/ref/NumericArray.html
45#[repr(transparent)]
46#[derive(ref_cast::RefCast)]
47pub struct NumericArray<T = ()>(sys::MNumericArray, PhantomData<T>);
48
49/// Represents an allocated [`NumericArray`] whose elements have not yet been initialized.
50///
51/// Use [`as_slice_mut()`][`UninitNumericArray::as_slice_mut()`] to initialize the
52/// elements of this [`UninitNumericArray`].
53pub struct UninitNumericArray<T: NumericArrayType>(sys::MNumericArray, PhantomData<T>);
54
55// Guard against accidental `derive(Copy)` annotations.
56assert_not_impl_any!(NumericArray: Copy);
57assert_not_impl_any!(UninitNumericArray<i64>: Copy);
58
59//======================================
60// Traits
61//======================================
62
63/// Trait implemented for types that can be stored in a [`NumericArray`].
64///
65/// Those types are:
66///
67///   * [`u8`], [`u16`], [`u32`], [`u64`]
68///   * [`i8`], [`i16`], [`i32`], [`i64`]
69///   * [`f32`], [`f64`]
70///   * [`mcomplex`][sys::mcomplex]
71///
72/// [`NumericArrayDataType`] is an enumeration of all the types which satisfy this trait.
73pub trait NumericArrayType: private::Sealed {
74    /// The [`NumericArrayDataType`] which dynamically represents the type which this
75    /// trait is implemented for.
76    const TYPE: NumericArrayDataType;
77}
78
79mod private {
80    use crate::sys;
81
82    pub trait Sealed {}
83
84    impl Sealed for u8 {}
85    impl Sealed for u16 {}
86    impl Sealed for u32 {}
87    impl Sealed for u64 {}
88
89    impl Sealed for i8 {}
90    impl Sealed for i16 {}
91    impl Sealed for i32 {}
92    impl Sealed for i64 {}
93
94    impl Sealed for f32 {}
95    impl Sealed for f64 {}
96
97    // impl Sealed for sys::complexreal32 {}
98    impl Sealed for sys::mcomplex {}
99}
100
101impl NumericArrayType for i8 {
102    const TYPE: NumericArrayDataType = NumericArrayDataType::Bit8;
103}
104impl NumericArrayType for i16 {
105    const TYPE: NumericArrayDataType = NumericArrayDataType::Bit16;
106}
107impl NumericArrayType for i32 {
108    const TYPE: NumericArrayDataType = NumericArrayDataType::Bit32;
109}
110impl NumericArrayType for i64 {
111    const TYPE: NumericArrayDataType = NumericArrayDataType::Bit64;
112}
113
114impl NumericArrayType for u8 {
115    const TYPE: NumericArrayDataType = NumericArrayDataType::UBit8;
116}
117impl NumericArrayType for u16 {
118    const TYPE: NumericArrayDataType = NumericArrayDataType::UBit16;
119}
120impl NumericArrayType for u32 {
121    const TYPE: NumericArrayDataType = NumericArrayDataType::UBit32;
122}
123impl NumericArrayType for u64 {
124    const TYPE: NumericArrayDataType = NumericArrayDataType::UBit64;
125}
126
127impl NumericArrayType for f32 {
128    const TYPE: NumericArrayDataType = NumericArrayDataType::Real32;
129}
130impl NumericArrayType for f64 {
131    const TYPE: NumericArrayDataType = NumericArrayDataType::Real64;
132}
133
134// TODO: Why is there no WolframLibrary.h typedef for 32-bit complex reals?
135// impl NumericArrayType for sys::complexreal32 {
136//     const TYPE: NumericArrayDataType = NumericArrayDataType::ComplexReal32;
137// }
138impl NumericArrayType for sys::mcomplex {
139    const TYPE: NumericArrayDataType = NumericArrayDataType::ComplexReal64;
140}
141
142//======================================
143// Enums
144//======================================
145
146/// The type of the data being stored in a [`NumericArray`].
147///
148/// This is an enumeration of all the types which satisfy [`NumericArrayType`].
149#[derive(Debug, Copy, Clone, PartialEq, Eq)]
150#[repr(u32)]
151#[allow(missing_docs)]
152pub enum NumericArrayDataType {
153    Bit8 = BIT8_TYPE as u32,
154    Bit16 = BIT16_TYPE as u32,
155    Bit32 = BIT32_TYPE as u32,
156    Bit64 = BIT64_TYPE as u32,
157
158    UBit8 = UBIT8_TYPE as u32,
159    UBit16 = UBIT16_TYPE as u32,
160    UBit32 = UBIT32_TYPE as u32,
161    UBit64 = UBIT64_TYPE as u32,
162
163    Real32 = REAL32_TYPE as u32,
164    Real64 = REAL64_TYPE as u32,
165
166    ComplexReal32 = COMPLEX_REAL32_TYPE as u32,
167    ComplexReal64 = COMPLEX_REAL64_TYPE as u32,
168}
169
170/// Conversion method used by [`NumericArray::convert_to()`].
171#[derive(Debug, Copy, Clone, PartialEq, Eq)]
172#[repr(u32)]
173#[allow(missing_docs)]
174pub enum NumericArrayConvertMethod {
175    Cast = MNumericArray_Convert_Cast as u32,
176    Check = MNumericArray_Convert_Check as u32,
177    Coerce = MNumericArray_Convert_Coerce as u32,
178    Round = MNumericArray_Convert_Round as u32,
179    Scale = MNumericArray_Convert_Scale as u32,
180    ClipAndCast = MNumericArray_Convert_Clip_Cast as u32,
181    ClipAndCheck = MNumericArray_Convert_Clip_Check as u32,
182    ClipAndCoerce = MNumericArray_Convert_Clip_Coerce as u32,
183    ClipAndRound = MNumericArray_Convert_Clip_Round as u32,
184    ClipAndScale = MNumericArray_Convert_Clip_Scale as u32,
185}
186
187/// Data array borrowed from a [`NumericArray`].
188///
189/// Use [`NumericArray::kind()`] to get an instance of this type.
190#[allow(missing_docs)]
191pub enum NumericArrayKind<'e> {
192    //
193    // Signed integer types
194    //
195    Bit8(&'e NumericArray<i8>),
196    Bit16(&'e NumericArray<i16>),
197    Bit32(&'e NumericArray<i32>),
198    Bit64(&'e NumericArray<i64>),
199
200    //
201    // Unsigned integer types
202    //
203    UBit8(&'e NumericArray<u8>),
204    UBit16(&'e NumericArray<u16>),
205    UBit32(&'e NumericArray<u32>),
206    UBit64(&'e NumericArray<u64>),
207
208    //
209    // Real types
210    //
211    Real32(&'e NumericArray<f32>),
212    Real64(&'e NumericArray<f64>),
213
214    //
215    // Complex types
216    //
217    // ComplexReal32(&'e NumericArray<sys::complexreal32>),
218    ComplexReal64(&'e NumericArray<sys::mcomplex>),
219}
220
221// Assert that `sys::mcomplex` is the 64-bit complex real type and not a 32-bit complex
222// real type.
223const _: () = assert!(mem::size_of::<sys::mcomplex>() == mem::size_of::<[f64; 2]>());
224const _: () = assert!(mem::align_of::<sys::mcomplex>() == mem::align_of::<f64>());
225
226//======================================
227// Impls
228//======================================
229
230impl NumericArray {
231    /// Dynamically resolve a `NumericArray` of unknown element type into a
232    /// `NumericArray<T>` with explicit element type.
233    ///
234    /// # Example
235    ///
236    /// Implement a function which returns the sum of an integral `NumericArray`
237    ///
238    /// ```no_run
239    /// use wolfram_library_link::{NumericArray, NumericArrayKind};
240    ///
241    /// fn sum(array: NumericArray) -> i64 {
242    ///     match array.kind() {
243    ///         NumericArrayKind::Bit8(na) => na.as_slice().into_iter().copied().map(i64::from).sum(),
244    ///         NumericArrayKind::Bit16(na) => na.as_slice().into_iter().copied().map(i64::from).sum(),
245    ///         NumericArrayKind::Bit32(na) => na.as_slice().into_iter().copied().map(i64::from).sum(),
246    ///         NumericArrayKind::Bit64(na) => na.as_slice().into_iter().sum(),
247    ///         NumericArrayKind::UBit8(na) => na.as_slice().into_iter().copied().map(i64::from).sum(),
248    ///         NumericArrayKind::UBit16(na) => na.as_slice().into_iter().copied().map(i64::from).sum(),
249    ///         NumericArrayKind::UBit32(na) => na.as_slice().into_iter().copied().map(i64::from).sum(),
250    ///         NumericArrayKind::UBit64(na) => {
251    ///             match i64::try_from(na.as_slice().into_iter().sum::<u64>()) {
252    ///                 Ok(sum) => sum,
253    ///                 Err(_) => panic!("overflows i64"),
254    ///             }
255    ///         },
256    ///         NumericArrayKind::Real32(_)
257    ///         | NumericArrayKind::Real64(_)
258    ///         | NumericArrayKind::ComplexReal64(_) => panic!("bad type"),
259    ///     }
260    /// }
261    /// ```
262    pub fn kind(&self) -> NumericArrayKind {
263        /// The purpose of this intermediate function is to limit the scope of the call to
264        /// transmute(). `transmute()` is a *very* unsafe function, so it seems prudent to
265        /// future-proof this code against accidental changes which alter the inferrence
266        /// of the transmute() target type.
267        unsafe fn trans<T: NumericArrayType>(array: &NumericArray) -> &NumericArray<T> {
268            std::mem::transmute(array)
269        }
270
271        unsafe {
272            use NumericArrayDataType::*;
273
274            match self.data_type() {
275                Bit8 => NumericArrayKind::Bit8(trans(self)),
276                Bit16 => NumericArrayKind::Bit16(trans(self)),
277                Bit32 => NumericArrayKind::Bit32(trans(self)),
278                Bit64 => NumericArrayKind::Bit64(trans(self)),
279
280                UBit8 => NumericArrayKind::UBit8(trans(self)),
281                UBit16 => NumericArrayKind::UBit16(trans(self)),
282                UBit32 => NumericArrayKind::UBit32(trans(self)),
283                UBit64 => NumericArrayKind::UBit64(trans(self)),
284
285                Real32 => NumericArrayKind::Real32(trans(self)),
286                Real64 => NumericArrayKind::Real64(trans(self)),
287
288                // TODO: Handle this case? Is there a 32 bit complex real typedef?
289                ComplexReal32 => unimplemented!(
290                    "NumericArray::kind(): NumericArray of ComplexReal32 is not currently supported."
291                ),
292                // ComplexReal32 => NumericArrayKind::ComplexReal32(trans(self)),
293                ComplexReal64 => NumericArrayKind::ComplexReal64(trans(self)),
294            }
295        }
296    }
297
298    /// Attempt to resolve this `NumericArray` into a `&NumericArray<T>` of the specified
299    /// element type.
300    ///
301    /// If the element type of this array does not match `T`, an error will be returned.
302    ///
303    /// # Example
304    ///
305    /// Implement a function which unwraps the `&[u8]` data in a `NumericArray` of 8-bit
306    /// integers.
307    ///
308    /// ```no_run
309    /// use wolfram_library_link::NumericArray;
310    ///
311    /// fn bytes(array: &NumericArray) -> &[u8] {
312    ///     let byte_array: &NumericArray<u8> = match array.try_kind::<u8>() {
313    ///         Ok(array) => array,
314    ///         Err(_) => panic!("expected NumericArray of UnsignedInteger8")
315    ///     };
316    ///
317    ///     byte_array.as_slice()
318    /// }
319    /// ```
320    pub fn try_kind<T>(&self) -> Result<&NumericArray<T>, ()>
321    where
322        T: NumericArrayType,
323    {
324        /// The purpose of this intermediate function is to limit the scope of the call to
325        /// transmute(). `transmute()` is a *very* unsafe function, so it seems prudent to
326        /// future-proof this code against accidental changes which alter the inferrence
327        /// of the transmute() target type.
328        unsafe fn trans<T: NumericArrayType>(array: &NumericArray) -> &NumericArray<T> {
329            std::mem::transmute(array)
330        }
331
332        if self.data_type() == T::TYPE {
333            return Ok(unsafe { trans(self) });
334        }
335
336        Err(())
337    }
338
339    /// Attempt to resolve this `NumericArray` into a `NumericArray<T>` of the specified
340    /// element type.
341    ///
342    /// If the element type of this array does not match `T`, the original untyped array
343    /// will be returned as the error value.
344    pub fn try_into_kind<T>(self) -> Result<NumericArray<T>, NumericArray>
345    where
346        T: NumericArrayType,
347    {
348        /// The purpose of this intermediate function is to limit the scope of the call to
349        /// transmute(). `transmute()` is a *very* unsafe function, so it seems prudent to
350        /// future-proof this code against accidental changes which alter the inferrence
351        /// of the transmute() target type.
352        unsafe fn trans<T: NumericArrayType>(array: NumericArray) -> NumericArray<T> {
353            std::mem::transmute(array)
354        }
355
356        if self.data_type() == T::TYPE {
357            return Ok(unsafe { trans(self) });
358        }
359
360        Err(self)
361    }
362}
363
364impl<T: NumericArrayType> NumericArray<T> {
365    /// Construct a new one-dimensional [`NumericArray`] from a slice.
366    ///
367    /// Use [`NumericArray::from_array()`] to construct multidimensional numeric arrays.
368    ///
369    /// # Panics
370    ///
371    /// This function will panic if [`NumericArray::try_from_array()`] returns
372    /// an error.
373    ///
374    /// # Example
375    ///
376    /// ```no_run
377    /// # use wolfram_library_link::NumericArray;
378    /// let array = NumericArray::from_slice(&[1, 2, 3, 4, 5]);
379    /// ```
380    ///
381    /// # Alternatives
382    ///
383    /// [`UninitNumericArray`] can be used to allocate a mutable numeric array,
384    /// eliminating the need for an intermediate allocation.
385    pub fn from_slice(data: &[T]) -> NumericArray<T> {
386        NumericArray::<T>::try_from_slice(data)
387            .expect("failed to create NumericArray from slice")
388    }
389
390    /// Fallible alternative to [`NumericArray::from_slice()`].
391    pub fn try_from_slice(data: &[T]) -> Result<NumericArray<T>, sys::errcode_t> {
392        let dim1 = data.len();
393
394        NumericArray::try_from_array(&[dim1], data)
395    }
396
397    /// Construct a new multidimensional [`NumericArray`] from a list of dimensions and the
398    /// flat slice of data.
399    ///
400    /// # Panics
401    ///
402    /// This function will panic if [`NumericArray::try_from_array()`] returns
403    /// an error.
404    ///
405    /// # Example
406    ///
407    /// Construct the 2x2 [`NumericArray`] `{{1, 2}, {3, 4}}` from a list of dimensions and
408    /// a flat buffer.
409    ///
410    /// ```no_run
411    /// # use wolfram_library_link::NumericArray;
412    /// let array = NumericArray::from_array(&[2, 2], &[1, 2, 3, 4]);
413    /// ```
414    pub fn from_array(dimensions: &[usize], data: &[T]) -> NumericArray<T> {
415        NumericArray::<T>::try_from_array(dimensions, data)
416            .expect("failed to create NumericArray from array")
417    }
418
419    /// Fallible alternative to [`NumericArray::from_array()`].
420    ///
421    /// This function will return an error if:
422    ///
423    /// * `dimensions` is empty
424    /// * the product of `dimensions` is 0
425    /// * `data.len()` is not equal to the product of `dimensions`
426    pub fn try_from_array(
427        dimensions: &[usize],
428        data: &[T],
429    ) -> Result<NumericArray<T>, sys::errcode_t> {
430        let uninit = UninitNumericArray::try_from_dimensions(dimensions)?;
431
432        Ok(uninit.init_from_slice(data))
433    }
434
435    /// Access the elements stored in this [`NumericArray`] as a flat buffer.
436    pub fn as_slice(&self) -> &[T] {
437        let ptr: *mut c_void = self.data_ptr();
438
439        debug_assert!(!ptr.is_null());
440
441        // Assert that `ptr` is aligned to `T`.
442        debug_assert!(ptr as usize % std::mem::size_of::<T>() == 0);
443
444        let ptr = ptr as *const T;
445
446        unsafe { std::slice::from_raw_parts(ptr, self.flattened_length()) }
447    }
448
449    /// Access the elements stored in this [`NumericArray`] as a mutable flat buffer.
450    ///
451    /// If the [`share_count()`][NumericArray::share_count] of this array is >= 1, this
452    /// function will return `None`.
453    pub fn as_slice_mut(&mut self) -> Option<&mut [T]> {
454        if self.share_count() == 0 {
455            // This is not a shared numeric array. We have unique access to it's data.
456            unsafe { Some(self.as_slice_mut_unchecked()) }
457        } else {
458            None
459        }
460    }
461
462    /// Access the elements stored in this [`NumericArray`] as a mutable flat buffer.
463    ///
464    /// # Safety
465    ///
466    /// `NumericArray` is an immutable shared data structure. There is no robust, easy way
467    /// to determine whether mutation of a `NumericArray` is safe. Prefer to use
468    /// [`UninitNumericArray`] to create and initialize a numeric array value instead of
469    /// mutating an existing `NumericArray`.
470    pub unsafe fn as_slice_mut_unchecked(&mut self) -> &mut [T] {
471        let ptr: *mut c_void = self.data_ptr();
472
473        debug_assert!(!ptr.is_null());
474
475        // Assert that `ptr` is aligned to `T`.
476        debug_assert!(ptr as usize % std::mem::size_of::<T>() == 0);
477
478        let ptr = ptr as *mut T;
479
480        std::slice::from_raw_parts_mut(ptr, self.flattened_length())
481    }
482}
483
484impl<T> NumericArray<T> {
485    /// Erase the concrete `T` data type associated with this `NumericArray`.
486    ///
487    /// Use [`NumericArray::try_into_kind()`] to convert back into a `NumericArray<T>`.
488    ///
489    /// # Example
490    ///
491    /// ```no_run
492    /// # use wolfram_library_link::NumericArray;
493    /// let array: NumericArray<i64> = NumericArray::from_slice(&[1, 2, 3]);
494    ///
495    /// let array: NumericArray = array.into_generic();
496    /// ```
497    pub fn into_generic(self) -> NumericArray {
498        let NumericArray(na, PhantomData) = self;
499
500        // Don't run Drop on `self`; ownership of this value is being given to the caller.
501        std::mem::forget(self);
502
503        NumericArray(na, PhantomData)
504    }
505
506    /// Construct a `NumericArray<T>` from a raw [`MNumericArray`][sys::MNumericArray].
507    ///
508    /// # Safety
509    ///
510    /// The following conditions must be met for safe usage of this function:
511    ///
512    /// * `array` must be a fully initialized and valid numeric array object
513    /// * `T` must either:
514    ///   - be `()`, representing an array with dynamic element type, or
515    ///   - `T` must satisfy [`NumericArrayType`], and the element type of `array` must
516    ///     be the same as `T`.
517    // TODO: Add something about the reference count in the above list?
518    pub unsafe fn from_raw(array: sys::MNumericArray) -> NumericArray<T> {
519        NumericArray(array, PhantomData)
520    }
521
522    /// Convert this `NumericArray` into a raw [`MNumericArray`][sys::MNumericArray]
523    /// object.
524    pub unsafe fn into_raw(self) -> sys::MNumericArray {
525        let NumericArray(raw, PhantomData) = self;
526
527        // Don't run Drop on `self`; ownership of this value is being given to the caller.
528        std::mem::forget(self);
529
530        raw
531    }
532
533    /// *LibraryLink C API Documentation:* [`MNumericArray_getData`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_getData.html)
534    pub fn data_ptr(&self) -> *mut c_void {
535        let NumericArray(numeric_array, _) = *self;
536
537        unsafe { data_ptr(numeric_array) }
538    }
539
540    #[allow(missing_docs)]
541    pub fn data_type(&self) -> NumericArrayDataType {
542        let value: sys::numericarray_data_t = self.data_type_raw();
543        let value: u32 = value as u32;
544
545        NumericArrayDataType::try_from(value)
546            .expect("NumericArray tensor property type is value is not a known NumericArrayDataType variant")
547    }
548
549    /// *LibraryLink C API Documentation:* [`MNumericArray_getType`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_getType.html)
550    pub fn data_type_raw(&self) -> sys::numericarray_data_t {
551        let NumericArray(numeric_array, _) = *self;
552
553        unsafe { rtl::MNumericArray_getType(numeric_array) }
554    }
555
556    /// The number of elements in the underlying flat data array.
557    ///
558    /// This is the product of the dimension lengths of this [`NumericArray`].
559    ///
560    /// This is *not* the number of bytes.
561    ///
562    /// *LibraryLink C API Documentation:* [`MNumericArray_getFlattenedLength`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_getFlattenedLength.html)
563    pub fn flattened_length(&self) -> usize {
564        let NumericArray(numeric_array, _) = *self;
565
566        let len = unsafe { flattened_length(numeric_array) };
567
568        // Check that the stored length matches the length computed from the dimensions.
569        debug_assert!(len == self.dimensions().iter().copied().product::<usize>());
570
571        len
572    }
573
574    /// *LibraryLink C API Documentation:* [`MNumericArray_getRank`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_getRank.html)
575    pub fn rank(&self) -> usize {
576        let NumericArray(numeric_array, _) = *self;
577
578        let rank: sys::mint = unsafe { rtl::MNumericArray_getRank(numeric_array) };
579
580        let rank = usize::try_from(rank).expect("NumericArray rank overflows usize");
581
582        rank
583    }
584
585    /// Get the dimensions of this `NumericArray`.
586    ///
587    /// *LibraryLink C API Documentation:* [`MNumericArray_getDimensions`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_getDimensions.html)
588    ///
589    /// # Example
590    ///
591    /// ```no_run
592    /// # use wolfram_library_link::NumericArray;
593    /// let array = NumericArray::from_array(&[2, 2], &[1, 2, 3, 4]);
594    ///
595    /// assert_eq!(array.dimensions(), &[2, 2]);
596    /// assert_eq!(array.rank(), array.dimensions().len());
597    /// ```
598    pub fn dimensions(&self) -> &[usize] {
599        let NumericArray(numeric_array, _) = *self;
600
601        let rank = self.rank();
602
603        debug_assert!(rank != 0);
604
605        let dims: *const crate::sys::mint =
606            unsafe { rtl::MNumericArray_getDimensions(numeric_array) };
607
608        const _: () = assert!(mem::size_of::<sys::mint>() == mem::size_of::<usize>());
609        let dims: *mut usize = dims as *mut usize;
610
611        debug_assert!(!dims.is_null());
612
613        unsafe { std::slice::from_raw_parts(dims, rank) }
614    }
615
616    /// Returns the share count of this `NumericArray`.
617    ///
618    /// If this `NumericArray` is not shared, the share count is 0.
619    ///
620    /// If this `NumericArray` was passed into the current library "by reference" due to
621    /// use of the `Automatic` or `"Constant"` memory management strategy, that reference
622    /// is not reflected in the `share_count()`.
623    ///
624    /// *LibraryLink C API Documentation:* [`MNumericArray_shareCount`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_shareCount.html)
625    pub fn share_count(&self) -> usize {
626        let NumericArray(raw, PhantomData) = *self;
627
628        let count: sys::mint = unsafe { rtl::MNumericArray_shareCount(raw) };
629
630        usize::try_from(count).expect("NumericArray share count mint overflows usize")
631    }
632
633    /// Returns true if `self` and `other` are pointers to the name underlying
634    /// numeric array object.
635    pub fn ptr_eq<T2>(&self, other: &NumericArray<T2>) -> bool {
636        let NumericArray(this, PhantomData) = *self;
637        let NumericArray(other, PhantomData) = *other;
638
639        this == other
640    }
641
642    /// *LibraryLink C API Documentation:* [`MNumericArray_convertType`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_convertType.html)
643    // TODO: When can this return an error? ClipAndCheck and the tolerance is not sufficient?
644    // TODO: Return a better error than `errcode_t`.
645    pub fn convert_to<T2: NumericArrayType>(
646        &self,
647        method: NumericArrayConvertMethod,
648        tolerance: sys::mreal,
649    ) -> Result<NumericArray<T2>, sys::errcode_t> {
650        let NumericArray(self_raw, PhantomData) = *self;
651
652        let mut new_raw: sys::MNumericArray = std::ptr::null_mut();
653
654        let err_code: sys::errcode_t = unsafe {
655            rtl::MNumericArray_convertType(
656                &mut new_raw,
657                self_raw,
658                T2::TYPE.as_raw(),
659                method.as_raw(),
660                tolerance,
661            )
662        };
663
664        if err_code != 0 || new_raw.is_null() {
665            return Err(err_code);
666        }
667
668        Ok(unsafe { NumericArray::<T2>::from_raw(new_raw) })
669    }
670}
671
672unsafe fn data_ptr(numeric_array: sys::MNumericArray) -> *mut c_void {
673    rtl::MNumericArray_getData(numeric_array)
674}
675
676unsafe fn flattened_length(numeric_array: sys::MNumericArray) -> usize {
677    let len: sys::mint = rtl::MNumericArray_getFlattenedLength(numeric_array);
678
679    let len = usize::try_from(len).expect("i64 overflows usize");
680
681    len
682}
683
684//======================================
685// UninitNumericArray
686//======================================
687
688impl<T: NumericArrayType> UninitNumericArray<T> {
689    /// Construct a new uninitialized `NumericArray` with the specified dimensions.
690    ///
691    /// # Panics
692    ///
693    /// This function will panic if [`UninitNumericArray::try_from_dimensions()`] returns
694    /// an error.
695    pub fn from_dimensions(dimensions: &[usize]) -> UninitNumericArray<T> {
696        UninitNumericArray::try_from_dimensions(dimensions)
697            .expect("failed to create UninitNumericArray from dimensions")
698    }
699
700    /// Try to construct a new uninitialized NumericArray with the specified dimensions.
701    ///
702    /// This function will return an error if:
703    ///
704    /// * `dimensions` is empty.
705    /// * the product of `dimensions` is equal to 0.
706    /// * the underlying allocation function returns `NULL`.
707    pub fn try_from_dimensions(
708        dimensions: &[usize],
709    ) -> Result<UninitNumericArray<T>, sys::errcode_t> {
710        assert!(!dimensions.is_empty());
711
712        let rank = dimensions.len();
713        debug_assert!(rank > 0);
714
715        unsafe {
716            let mut numeric_array: sys::MNumericArray = std::ptr::null_mut();
717
718            let err_code: sys::errcode_t = rtl::MNumericArray_new(
719                <T as NumericArrayType>::TYPE.as_raw(),
720                i64::try_from(rank).expect("usize overflows i64"),
721                dimensions.as_ptr() as *mut sys::mint,
722                &mut numeric_array,
723            );
724
725            if err_code != 0 || numeric_array.is_null() {
726                return Err(err_code);
727            }
728
729            Ok(UninitNumericArray(numeric_array, PhantomData))
730        }
731    }
732
733    /// # Panics
734    ///
735    /// This function will panic if `source` does not have the same length as
736    /// this array's [`as_slice_mut()`][UninitNumericArray::as_slice_mut] slice.
737    pub fn init_from_slice(mut self, source: &[T]) -> NumericArray<T> {
738        let data = self.as_slice_mut();
739
740        // Safety: copy_from_slice_uninit() unconditionally asserts that `data` and
741        //         `source` have the same number of elements, so if it succeeds we're
742        //         certain that every element of the NumericArray has been initialized.
743        copy_from_slice_uninit(source, data);
744
745        unsafe { self.assume_init() }
746    }
747
748    /// Mutable access to the elements of this [`UninitNumericArray`].
749    ///
750    /// This function returns a mutable slice of [`std::mem::MaybeUninit<T>`]. This is done
751    /// because it is undefined behavior in Rust to construct a `&` (or `&mut`) reference
752    /// to a value which has not been initialized. Note that it is undefined behavior even
753    /// if the reference is never read from. The `MaybeUninit` type explicitly makes the
754    /// compiler aware that the `T` value might not be initialized.
755    ///
756    /// # Example
757    ///
758    /// Construct the numeric array `{1, 2, 3, 4, 5}`.
759    ///
760    /// ```no_run
761    /// use wolfram_library_link::{NumericArray, UninitNumericArray};
762    ///
763    /// // Construct a `1x5` numeric array with elements of type `f64`.
764    /// let mut uninit = UninitNumericArray::<f64>::from_dimensions(&[5]);
765    ///
766    /// for (index, elem) in uninit.as_slice_mut().into_iter().enumerate() {
767    ///     elem.write(index as f64 + 1.0);
768    /// }
769    ///
770    /// // Now that we've taken responsibility for initializing every
771    /// // element of the UninitNumericArray, we've upheld the
772    /// // invariant necessary to make a call to `assume_init()` safe.
773    /// let array: NumericArray<f64> = unsafe { uninit.assume_init() };
774    /// ```
775    ///
776    /// See [`assume_init()`][UninitNumericArray::assume_init].
777    pub fn as_slice_mut(&mut self) -> &mut [MaybeUninit<T>] {
778        let UninitNumericArray(numeric_array, PhantomData) = *self;
779
780        unsafe {
781            let len = flattened_length(numeric_array);
782
783            let ptr: *mut c_void = data_ptr(numeric_array);
784            let ptr = ptr as *mut MaybeUninit<T>;
785
786            std::slice::from_raw_parts_mut(ptr, len)
787        }
788    }
789
790    /// Assume that this NumericArray's elements have been initialized.
791    ///
792    /// Use [`as_slice_mut()`][UninitNumericArray::as_slice_mut] to initialize the values
793    /// in this array.
794    ///
795    /// # Safety
796    ///
797    /// This function must only be called once all elements of this NumericArray have
798    /// been initialized. It is undefined behavior to construct a [`NumericArray`] without
799    /// first initializing the data array.
800    pub unsafe fn assume_init(self) -> NumericArray<T> {
801        let UninitNumericArray(expr, PhantomData) = self;
802
803        // Don't run Drop on `self`; ownership of this value is being given to the caller.
804        std::mem::forget(self);
805
806        NumericArray(expr, PhantomData)
807    }
808}
809
/// Counterpart of the slice method `copy_from_slice()` for a destination whose elements
/// may not be initialized yet. This is how an [`UninitNumericArray`] gets filled from
/// existing data.
///
/// Every element of `dest` is overwritten with a bitwise copy of the element at the same
/// index in `src`. Because the copy is bitwise, it is only appropriate for plain data
/// element types; the caller in this file (`init_from_slice`) only passes
/// [`NumericArrayType`] element types.
///
/// # Panics
///
/// Panics if `src` and `dest` do not have the same length.
fn copy_from_slice_uninit<T>(src: &[T], dest: &mut [MaybeUninit<T>]) {
    let count = dest.len();

    assert_eq!(
        src.len(),
        count,
        "destination and source slices have different lengths"
    );

    let from: *const T = src.as_ptr();
    let to: *mut T = dest.as_mut_ptr().cast::<T>();

    // SAFETY: both regions hold exactly `count` elements (asserted above), and
    // `MaybeUninit<T>` has the same layout as `T`. `src` is a shared borrow and `dest` a
    // mutable borrow, so the two regions cannot overlap.
    unsafe { std::ptr::copy_nonoverlapping(from, to, count) }
}
828
829impl NumericArrayDataType {
830    #[allow(missing_docs)]
831    pub fn as_raw(self) -> sys::numericarray_data_t {
832        self as sys::numericarray_data_t
833    }
834
835    /// Get the string name of this type, suitable for use in
836    /// [`NumericArray`][ref/NumericArray]<code>[<i>data</i>, &quot;<i>type</i>&quot;]</code>.
837    ///
838    /// [ref/NumericArray]: https://reference.wolfram.com/language/ref/NumericArray.html
839    #[rustfmt::skip]
840    pub fn name(&self) -> &'static str {
841        match self {
842            NumericArrayDataType::Bit8  => "Integer8",
843            NumericArrayDataType::Bit16 => "Integer16",
844            NumericArrayDataType::Bit32 => "Integer32",
845            NumericArrayDataType::Bit64 => "Integer64",
846
847            NumericArrayDataType::UBit8  => "UnsignedInteger8",
848            NumericArrayDataType::UBit16 => "UnsignedInteger16",
849            NumericArrayDataType::UBit32 => "UnsignedInteger32",
850            NumericArrayDataType::UBit64 => "UnsignedInteger64",
851
852            NumericArrayDataType::Real32 => "Real32",
853            NumericArrayDataType::Real64 => "Real64",
854
855            NumericArrayDataType::ComplexReal32 => "ComplexReal32",
856            NumericArrayDataType::ComplexReal64 => "ComplexReal64",
857        }
858    }
859}
860
861impl NumericArrayConvertMethod {
862    #[allow(missing_docs)]
863    pub fn as_raw(self) -> sys::numericarray_convert_method_t {
864        self as sys::numericarray_convert_method_t
865    }
866}
867
868//======================================
869// Trait Impls
870//======================================
871
872impl<T> Clone for NumericArray<T> {
873    fn clone(&self) -> NumericArray<T> {
874        let NumericArray(raw, PhantomData) = *self;
875
876        unsafe {
877            let mut new: sys::MNumericArray = std::ptr::null_mut();
878            let err_code: sys::errcode_t = rtl::MNumericArray_clone(raw, &mut new);
879
880            if err_code != 0 || new.is_null() {
881                panic!("NumericArray clone failed with error code: {}", err_code);
882            }
883
884            NumericArray::<T>::from_raw(new)
885        }
886    }
887}
888
889impl<T> Drop for NumericArray<T> {
890    fn drop(&mut self) {
891        if self.share_count() > 0 {
892            // This is a "Shared" numeric array, so we should decrement the reference
893            // count.
894            let NumericArray(raw, PhantomData) = *self;
895            unsafe { rtl::MNumericArray_disown(raw) }
896        } else {
897            // This is a "Manual" numeric array (or one created within Rust), so we should
898            // free its memory directly.
899            let NumericArray(raw, PhantomData) = *self;
900            unsafe { rtl::MNumericArray_free(raw) }
901        }
902    }
903}
904
905impl<T> fmt::Debug for NumericArray<T> {
906    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
907        f.debug_struct("NumericArray")
908            .field("raw", &self.0)
909            .field("data_type", &self.data_type())
910            .finish()
911    }
912}
913
914//======================================
915// Conversion Impls
916//======================================
917
918impl TryFrom<u32> for NumericArrayDataType {
919    type Error = ();
920
921    fn try_from(value: u32) -> Result<Self, Self::Error> {
922        // debug_assert!(u32::try_from(self.tensor_property_type()).is_ok());
923
924        #[rustfmt::skip]
925        let ok = match value {
926            _ if value == BIT8_TYPE as u32 => NumericArrayDataType::Bit8,
927            _ if value == BIT16_TYPE as u32 => NumericArrayDataType::Bit16,
928            _ if value == BIT32_TYPE as u32 => NumericArrayDataType::Bit32,
929            _ if value == BIT64_TYPE as u32 => NumericArrayDataType::Bit64,
930
931            _ if value == UBIT8_TYPE as u32 => NumericArrayDataType::UBit8,
932            _ if value == UBIT16_TYPE as u32 => NumericArrayDataType::UBit16,
933            _ if value == UBIT32_TYPE as u32 => NumericArrayDataType::UBit32,
934            _ if value == UBIT64_TYPE as u32 => NumericArrayDataType::UBit64,
935
936            _ if value == REAL32_TYPE as u32 => NumericArrayDataType::Real32,
937            _ if value == REAL64_TYPE as u32 => NumericArrayDataType::Real64,
938
939            _ if value == COMPLEX_REAL32_TYPE as u32 => NumericArrayDataType::ComplexReal32,
940            _ if value == COMPLEX_REAL64_TYPE as u32 => NumericArrayDataType::ComplexReal64,
941
942            _ => return Err(()),
943        };
944
945        Ok(ok)
946    }
947}