wolfram_library_link/numeric_array.rs
1use std::ffi::c_void;
2use std::fmt;
3use std::marker::PhantomData;
4use std::mem::{self, MaybeUninit};
5
6use static_assertions::assert_not_impl_any;
7
8use crate::{rtl, sys};
9
10#[rustfmt::skip]
11use crate::sys::MNumericArray_Data_Type::{
12 MNumericArray_Type_Bit8 as BIT8_TYPE,
13 MNumericArray_Type_Bit16 as BIT16_TYPE,
14 MNumericArray_Type_Bit32 as BIT32_TYPE,
15 MNumericArray_Type_Bit64 as BIT64_TYPE,
16
17 MNumericArray_Type_UBit8 as UBIT8_TYPE,
18 MNumericArray_Type_UBit16 as UBIT16_TYPE,
19 MNumericArray_Type_UBit32 as UBIT32_TYPE,
20 MNumericArray_Type_UBit64 as UBIT64_TYPE,
21
22 MNumericArray_Type_Real32 as REAL32_TYPE,
23 MNumericArray_Type_Real64 as REAL64_TYPE,
24
25 MNumericArray_Type_Complex_Real32 as COMPLEX_REAL32_TYPE,
26 MNumericArray_Type_Complex_Real64 as COMPLEX_REAL64_TYPE,
27};
28
29use crate::sys::MNumericArray_Convert_Method::*;
30
/// Native Wolfram [`NumericArray`][ref/NumericArray]<sub>WL</sub>.
///
/// This type is an ABI-compatible wrapper around [`wolfram_library_link_sys::MNumericArray`].
///
/// A [`NumericArray`] can contain any type `T` which satisfies the trait
/// [`NumericArrayType`]. The default parameter, `T = ()`, represents an array whose
/// element type is only known dynamically.
///
/// Use [`NumericArray::kind()`] to dynamically resolve a `NumericArray` with unknown
/// element type into a `NumericArray<T>` with explicit element type.
///
/// Use [`UninitNumericArray`] to construct a [`NumericArray`] without requiring an
/// intermediate allocation to copy the elements from.
///
/// Dropping a `NumericArray` releases the underlying array object: it is disowned when
/// its share count is greater than zero and freed otherwise.
///
/// [ref/NumericArray]: https://reference.wolfram.com/language/ref/NumericArray.html
// `repr(transparent)` gives this wrapper the same layout as the raw
// `sys::MNumericArray` handle for every `T`; the zero-sized `PhantomData<T>` marker only
// records the element type at compile time.
#[repr(transparent)]
#[derive(ref_cast::RefCast)]
pub struct NumericArray<T = ()>(sys::MNumericArray, PhantomData<T>);
48
/// Represents an allocated [`NumericArray`] whose elements have not yet been initialized.
///
/// Use [`as_slice_mut()`][`UninitNumericArray::as_slice_mut()`] to initialize the
/// elements of this [`UninitNumericArray`], then convert it into a finished
/// [`NumericArray`] with [`assume_init()`][`UninitNumericArray::assume_init()`].
/// [`init_from_slice()`][`UninitNumericArray::init_from_slice()`] performs both steps
/// by copying from an existing slice.
// NOTE(review): no `Drop` impl for this type is visible in this file, so an
// `UninitNumericArray` that is dropped without being converted presumably leaks its
// allocation -- confirm.
pub struct UninitNumericArray<T: NumericArrayType>(sys::MNumericArray, PhantomData<T>);
54
// Guard against accidental `derive(Copy)` annotations. Both wrappers hold a raw array
// handle, and `NumericArray` releases that handle when dropped, so these types must
// never be bitwise-copyable. The checks are evaluated at compile time.
assert_not_impl_any!(NumericArray: Copy);
assert_not_impl_any!(UninitNumericArray<i64>: Copy);
58
59//======================================
60// Traits
61//======================================
62
/// Trait implemented for types that can be stored in a [`NumericArray`].
///
/// Those types are:
///
/// * [`u8`], [`u16`], [`u32`], [`u64`]
/// * [`i8`], [`i16`], [`i32`], [`i64`]
/// * [`f32`], [`f64`]
/// * [`mcomplex`][sys::mcomplex]
///
/// [`NumericArrayDataType`] is an enumeration of all the types which satisfy this trait.
///
/// This trait has the private `Sealed` trait as a supertrait, so the set of
/// implementing types is fixed by this module.
pub trait NumericArrayType: private::Sealed {
    /// The [`NumericArrayDataType`] which dynamically represents the type which this
    /// trait is implemented for.
    const TYPE: NumericArrayDataType;
}
78
// Private module holding the `Sealed` supertrait of `NumericArrayType`. The module is
// not `pub`, so only the types listed here can satisfy the supertrait bound.
mod private {
    use crate::sys;

    // Marker trait with no items; its only purpose is to restrict who can implement
    // `NumericArrayType`.
    pub trait Sealed {}

    // Unsigned integer element types.
    impl Sealed for u8 {}
    impl Sealed for u16 {}
    impl Sealed for u32 {}
    impl Sealed for u64 {}

    // Signed integer element types.
    impl Sealed for i8 {}
    impl Sealed for i16 {}
    impl Sealed for i32 {}
    impl Sealed for i64 {}

    // Real element types.
    impl Sealed for f32 {}
    impl Sealed for f64 {}

    // Complex element types. The 32-bit complex type is disabled here to match the
    // commented-out `NumericArrayType` impl for it.
    // impl Sealed for sys::complexreal32 {}
    impl Sealed for sys::mcomplex {}
}
100
101impl NumericArrayType for i8 {
102 const TYPE: NumericArrayDataType = NumericArrayDataType::Bit8;
103}
104impl NumericArrayType for i16 {
105 const TYPE: NumericArrayDataType = NumericArrayDataType::Bit16;
106}
107impl NumericArrayType for i32 {
108 const TYPE: NumericArrayDataType = NumericArrayDataType::Bit32;
109}
110impl NumericArrayType for i64 {
111 const TYPE: NumericArrayDataType = NumericArrayDataType::Bit64;
112}
113
114impl NumericArrayType for u8 {
115 const TYPE: NumericArrayDataType = NumericArrayDataType::UBit8;
116}
117impl NumericArrayType for u16 {
118 const TYPE: NumericArrayDataType = NumericArrayDataType::UBit16;
119}
120impl NumericArrayType for u32 {
121 const TYPE: NumericArrayDataType = NumericArrayDataType::UBit32;
122}
123impl NumericArrayType for u64 {
124 const TYPE: NumericArrayDataType = NumericArrayDataType::UBit64;
125}
126
127impl NumericArrayType for f32 {
128 const TYPE: NumericArrayDataType = NumericArrayDataType::Real32;
129}
130impl NumericArrayType for f64 {
131 const TYPE: NumericArrayDataType = NumericArrayDataType::Real64;
132}
133
134// TODO: Why is there no WolframLibrary.h typedef for 32-bit complex reals?
135// impl NumericArrayType for sys::complexreal32 {
136// const TYPE: NumericArrayDataType = NumericArrayDataType::ComplexReal32;
137// }
138impl NumericArrayType for sys::mcomplex {
139 const TYPE: NumericArrayDataType = NumericArrayDataType::ComplexReal64;
140}
141
142//======================================
143// Enums
144//======================================
145
/// The type of the data being stored in a [`NumericArray`].
///
/// This is an enumeration of all the types which satisfy [`NumericArrayType`].
///
/// Each discriminant is the value of the corresponding
/// `sys::MNumericArray_Data_Type` constant, so a variant can be cast directly to the
/// raw type expected by the C API (see [`NumericArrayDataType::as_raw()`]).
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(u32)]
#[allow(missing_docs)]
pub enum NumericArrayDataType {
    /// Signed 8-bit integers; the element type is [`i8`].
    Bit8 = BIT8_TYPE as u32,
    /// Signed 16-bit integers; the element type is [`i16`].
    Bit16 = BIT16_TYPE as u32,
    /// Signed 32-bit integers; the element type is [`i32`].
    Bit32 = BIT32_TYPE as u32,
    /// Signed 64-bit integers; the element type is [`i64`].
    Bit64 = BIT64_TYPE as u32,

    /// Unsigned 8-bit integers; the element type is [`u8`].
    UBit8 = UBIT8_TYPE as u32,
    /// Unsigned 16-bit integers; the element type is [`u16`].
    UBit16 = UBIT16_TYPE as u32,
    /// Unsigned 32-bit integers; the element type is [`u32`].
    UBit32 = UBIT32_TYPE as u32,
    /// Unsigned 64-bit integers; the element type is [`u64`].
    UBit64 = UBIT64_TYPE as u32,

    /// 32-bit reals; the element type is [`f32`].
    Real32 = REAL32_TYPE as u32,
    /// 64-bit reals; the element type is [`f64`].
    Real64 = REAL64_TYPE as u32,

    /// 32-bit complex reals. No Rust type in this module currently implements
    /// [`NumericArrayType`] for this variant.
    ComplexReal32 = COMPLEX_REAL32_TYPE as u32,
    /// 64-bit complex reals; the element type is [`mcomplex`][sys::mcomplex].
    ComplexReal64 = COMPLEX_REAL64_TYPE as u32,
}
169
/// Conversion method used by [`NumericArray::convert_to()`].
///
/// Each discriminant is the value of the corresponding
/// `sys::MNumericArray_Convert_Method` constant. The precise semantics of each method
/// are defined by the LibraryLink `MNumericArray_convertType` function, not by this
/// crate.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(u32)]
#[allow(missing_docs)]
pub enum NumericArrayConvertMethod {
    /// Corresponds to `MNumericArray_Convert_Cast`.
    Cast = MNumericArray_Convert_Cast as u32,
    /// Corresponds to `MNumericArray_Convert_Check`.
    Check = MNumericArray_Convert_Check as u32,
    /// Corresponds to `MNumericArray_Convert_Coerce`.
    Coerce = MNumericArray_Convert_Coerce as u32,
    /// Corresponds to `MNumericArray_Convert_Round`.
    Round = MNumericArray_Convert_Round as u32,
    /// Corresponds to `MNumericArray_Convert_Scale`.
    Scale = MNumericArray_Convert_Scale as u32,
    /// Corresponds to `MNumericArray_Convert_Clip_Cast`.
    ClipAndCast = MNumericArray_Convert_Clip_Cast as u32,
    /// Corresponds to `MNumericArray_Convert_Clip_Check`.
    ClipAndCheck = MNumericArray_Convert_Clip_Check as u32,
    /// Corresponds to `MNumericArray_Convert_Clip_Coerce`.
    ClipAndCoerce = MNumericArray_Convert_Clip_Coerce as u32,
    /// Corresponds to `MNumericArray_Convert_Clip_Round`.
    ClipAndRound = MNumericArray_Convert_Clip_Round as u32,
    /// Corresponds to `MNumericArray_Convert_Clip_Scale`.
    ClipAndScale = MNumericArray_Convert_Clip_Scale as u32,
}
186
/// Data array borrowed from a [`NumericArray`].
///
/// Use [`NumericArray::kind()`] to get an instance of this type.
///
/// Each variant wraps a reference to the same array, viewed with the concrete element
/// type named by the variant. `'e` is the lifetime of the borrow of the original
/// [`NumericArray`].
#[allow(missing_docs)]
pub enum NumericArrayKind<'e> {
    //
    // Signed integer types
    //
    /// Array of [`i8`] elements.
    Bit8(&'e NumericArray<i8>),
    /// Array of [`i16`] elements.
    Bit16(&'e NumericArray<i16>),
    /// Array of [`i32`] elements.
    Bit32(&'e NumericArray<i32>),
    /// Array of [`i64`] elements.
    Bit64(&'e NumericArray<i64>),

    //
    // Unsigned integer types
    //
    /// Array of [`u8`] elements.
    UBit8(&'e NumericArray<u8>),
    /// Array of [`u16`] elements.
    UBit16(&'e NumericArray<u16>),
    /// Array of [`u32`] elements.
    UBit32(&'e NumericArray<u32>),
    /// Array of [`u64`] elements.
    UBit64(&'e NumericArray<u64>),

    //
    // Real types
    //
    /// Array of [`f32`] elements.
    Real32(&'e NumericArray<f32>),
    /// Array of [`f64`] elements.
    Real64(&'e NumericArray<f64>),

    //
    // Complex types
    //
    // The 32-bit complex variant is disabled because no Rust element type implements
    // `NumericArrayType` for it.
    // ComplexReal32(&'e NumericArray<sys::complexreal32>),
    /// Array of [`mcomplex`][sys::mcomplex] elements.
    ComplexReal64(&'e NumericArray<sys::mcomplex>),
}
220
// Compile-time checks that `sys::mcomplex` is the 64-bit complex real type and not a
// 32-bit complex real type: it must have the size of two `f64` values and the alignment
// of a single `f64`. `NumericArrayType for sys::mcomplex` maps it to `ComplexReal64`,
// which is only correct if this layout holds.
const _: () = assert!(mem::size_of::<sys::mcomplex>() == mem::size_of::<[f64; 2]>());
const _: () = assert!(mem::align_of::<sys::mcomplex>() == mem::align_of::<f64>());
225
226//======================================
227// Impls
228//======================================
229
230impl NumericArray {
231 /// Dynamically resolve a `NumericArray` of unknown element type into a
232 /// `NumericArray<T>` with explicit element type.
233 ///
234 /// # Example
235 ///
236 /// Implement a function which returns the sum of an integral `NumericArray`
237 ///
238 /// ```no_run
239 /// use wolfram_library_link::{NumericArray, NumericArrayKind};
240 ///
241 /// fn sum(array: NumericArray) -> i64 {
242 /// match array.kind() {
243 /// NumericArrayKind::Bit8(na) => na.as_slice().into_iter().copied().map(i64::from).sum(),
244 /// NumericArrayKind::Bit16(na) => na.as_slice().into_iter().copied().map(i64::from).sum(),
245 /// NumericArrayKind::Bit32(na) => na.as_slice().into_iter().copied().map(i64::from).sum(),
246 /// NumericArrayKind::Bit64(na) => na.as_slice().into_iter().sum(),
247 /// NumericArrayKind::UBit8(na) => na.as_slice().into_iter().copied().map(i64::from).sum(),
248 /// NumericArrayKind::UBit16(na) => na.as_slice().into_iter().copied().map(i64::from).sum(),
249 /// NumericArrayKind::UBit32(na) => na.as_slice().into_iter().copied().map(i64::from).sum(),
250 /// NumericArrayKind::UBit64(na) => {
251 /// match i64::try_from(na.as_slice().into_iter().sum::<u64>()) {
252 /// Ok(sum) => sum,
253 /// Err(_) => panic!("overflows i64"),
254 /// }
255 /// },
256 /// NumericArrayKind::Real32(_)
257 /// | NumericArrayKind::Real64(_)
258 /// | NumericArrayKind::ComplexReal64(_) => panic!("bad type"),
259 /// }
260 /// }
261 /// ```
262 pub fn kind(&self) -> NumericArrayKind {
263 /// The purpose of this intermediate function is to limit the scope of the call to
264 /// transmute(). `transmute()` is a *very* unsafe function, so it seems prudent to
265 /// future-proof this code against accidental changes which alter the inferrence
266 /// of the transmute() target type.
267 unsafe fn trans<T: NumericArrayType>(array: &NumericArray) -> &NumericArray<T> {
268 std::mem::transmute(array)
269 }
270
271 unsafe {
272 use NumericArrayDataType::*;
273
274 match self.data_type() {
275 Bit8 => NumericArrayKind::Bit8(trans(self)),
276 Bit16 => NumericArrayKind::Bit16(trans(self)),
277 Bit32 => NumericArrayKind::Bit32(trans(self)),
278 Bit64 => NumericArrayKind::Bit64(trans(self)),
279
280 UBit8 => NumericArrayKind::UBit8(trans(self)),
281 UBit16 => NumericArrayKind::UBit16(trans(self)),
282 UBit32 => NumericArrayKind::UBit32(trans(self)),
283 UBit64 => NumericArrayKind::UBit64(trans(self)),
284
285 Real32 => NumericArrayKind::Real32(trans(self)),
286 Real64 => NumericArrayKind::Real64(trans(self)),
287
288 // TODO: Handle this case? Is there a 32 bit complex real typedef?
289 ComplexReal32 => unimplemented!(
290 "NumericArray::kind(): NumericArray of ComplexReal32 is not currently supported."
291 ),
292 // ComplexReal32 => NumericArrayKind::ComplexReal32(trans(self)),
293 ComplexReal64 => NumericArrayKind::ComplexReal64(trans(self)),
294 }
295 }
296 }
297
298 /// Attempt to resolve this `NumericArray` into a `&NumericArray<T>` of the specified
299 /// element type.
300 ///
301 /// If the element type of this array does not match `T`, an error will be returned.
302 ///
303 /// # Example
304 ///
305 /// Implement a function which unwraps the `&[u8]` data in a `NumericArray` of 8-bit
306 /// integers.
307 ///
308 /// ```no_run
309 /// use wolfram_library_link::NumericArray;
310 ///
311 /// fn bytes(array: &NumericArray) -> &[u8] {
312 /// let byte_array: &NumericArray<u8> = match array.try_kind::<u8>() {
313 /// Ok(array) => array,
314 /// Err(_) => panic!("expected NumericArray of UnsignedInteger8")
315 /// };
316 ///
317 /// byte_array.as_slice()
318 /// }
319 /// ```
320 pub fn try_kind<T>(&self) -> Result<&NumericArray<T>, ()>
321 where
322 T: NumericArrayType,
323 {
324 /// The purpose of this intermediate function is to limit the scope of the call to
325 /// transmute(). `transmute()` is a *very* unsafe function, so it seems prudent to
326 /// future-proof this code against accidental changes which alter the inferrence
327 /// of the transmute() target type.
328 unsafe fn trans<T: NumericArrayType>(array: &NumericArray) -> &NumericArray<T> {
329 std::mem::transmute(array)
330 }
331
332 if self.data_type() == T::TYPE {
333 return Ok(unsafe { trans(self) });
334 }
335
336 Err(())
337 }
338
339 /// Attempt to resolve this `NumericArray` into a `NumericArray<T>` of the specified
340 /// element type.
341 ///
342 /// If the element type of this array does not match `T`, the original untyped array
343 /// will be returned as the error value.
344 pub fn try_into_kind<T>(self) -> Result<NumericArray<T>, NumericArray>
345 where
346 T: NumericArrayType,
347 {
348 /// The purpose of this intermediate function is to limit the scope of the call to
349 /// transmute(). `transmute()` is a *very* unsafe function, so it seems prudent to
350 /// future-proof this code against accidental changes which alter the inferrence
351 /// of the transmute() target type.
352 unsafe fn trans<T: NumericArrayType>(array: NumericArray) -> NumericArray<T> {
353 std::mem::transmute(array)
354 }
355
356 if self.data_type() == T::TYPE {
357 return Ok(unsafe { trans(self) });
358 }
359
360 Err(self)
361 }
362}
363
364impl<T: NumericArrayType> NumericArray<T> {
365 /// Construct a new one-dimensional [`NumericArray`] from a slice.
366 ///
367 /// Use [`NumericArray::from_array()`] to construct multidimensional numeric arrays.
368 ///
369 /// # Panics
370 ///
371 /// This function will panic if [`NumericArray::try_from_array()`] returns
372 /// an error.
373 ///
374 /// # Example
375 ///
376 /// ```no_run
377 /// # use wolfram_library_link::NumericArray;
378 /// let array = NumericArray::from_slice(&[1, 2, 3, 4, 5]);
379 /// ```
380 ///
381 /// # Alternatives
382 ///
383 /// [`UninitNumericArray`] can be used to allocate a mutable numeric array,
384 /// eliminating the need for an intermediate allocation.
385 pub fn from_slice(data: &[T]) -> NumericArray<T> {
386 NumericArray::<T>::try_from_slice(data)
387 .expect("failed to create NumericArray from slice")
388 }
389
390 /// Fallible alternative to [`NumericArray::from_slice()`].
391 pub fn try_from_slice(data: &[T]) -> Result<NumericArray<T>, sys::errcode_t> {
392 let dim1 = data.len();
393
394 NumericArray::try_from_array(&[dim1], data)
395 }
396
397 /// Construct a new multidimensional [`NumericArray`] from a list of dimensions and the
398 /// flat slice of data.
399 ///
400 /// # Panics
401 ///
402 /// This function will panic if [`NumericArray::try_from_array()`] returns
403 /// an error.
404 ///
405 /// # Example
406 ///
407 /// Construct the 2x2 [`NumericArray`] `{{1, 2}, {3, 4}}` from a list of dimensions and
408 /// a flat buffer.
409 ///
410 /// ```no_run
411 /// # use wolfram_library_link::NumericArray;
412 /// let array = NumericArray::from_array(&[2, 2], &[1, 2, 3, 4]);
413 /// ```
414 pub fn from_array(dimensions: &[usize], data: &[T]) -> NumericArray<T> {
415 NumericArray::<T>::try_from_array(dimensions, data)
416 .expect("failed to create NumericArray from array")
417 }
418
419 /// Fallible alternative to [`NumericArray::from_array()`].
420 ///
421 /// This function will return an error if:
422 ///
423 /// * `dimensions` is empty
424 /// * the product of `dimensions` is 0
425 /// * `data.len()` is not equal to the product of `dimensions`
426 pub fn try_from_array(
427 dimensions: &[usize],
428 data: &[T],
429 ) -> Result<NumericArray<T>, sys::errcode_t> {
430 let uninit = UninitNumericArray::try_from_dimensions(dimensions)?;
431
432 Ok(uninit.init_from_slice(data))
433 }
434
435 /// Access the elements stored in this [`NumericArray`] as a flat buffer.
436 pub fn as_slice(&self) -> &[T] {
437 let ptr: *mut c_void = self.data_ptr();
438
439 debug_assert!(!ptr.is_null());
440
441 // Assert that `ptr` is aligned to `T`.
442 debug_assert!(ptr as usize % std::mem::size_of::<T>() == 0);
443
444 let ptr = ptr as *const T;
445
446 unsafe { std::slice::from_raw_parts(ptr, self.flattened_length()) }
447 }
448
449 /// Access the elements stored in this [`NumericArray`] as a mutable flat buffer.
450 ///
451 /// If the [`share_count()`][NumericArray::share_count] of this array is >= 1, this
452 /// function will return `None`.
453 pub fn as_slice_mut(&mut self) -> Option<&mut [T]> {
454 if self.share_count() == 0 {
455 // This is not a shared numeric array. We have unique access to it's data.
456 unsafe { Some(self.as_slice_mut_unchecked()) }
457 } else {
458 None
459 }
460 }
461
462 /// Access the elements stored in this [`NumericArray`] as a mutable flat buffer.
463 ///
464 /// # Safety
465 ///
466 /// `NumericArray` is an immutable shared data structure. There is no robust, easy way
467 /// to determine whether mutation of a `NumericArray` is safe. Prefer to use
468 /// [`UninitNumericArray`] to create and initialize a numeric array value instead of
469 /// mutating an existing `NumericArray`.
470 pub unsafe fn as_slice_mut_unchecked(&mut self) -> &mut [T] {
471 let ptr: *mut c_void = self.data_ptr();
472
473 debug_assert!(!ptr.is_null());
474
475 // Assert that `ptr` is aligned to `T`.
476 debug_assert!(ptr as usize % std::mem::size_of::<T>() == 0);
477
478 let ptr = ptr as *mut T;
479
480 std::slice::from_raw_parts_mut(ptr, self.flattened_length())
481 }
482}
483
484impl<T> NumericArray<T> {
485 /// Erase the concrete `T` data type associated with this `NumericArray`.
486 ///
487 /// Use [`NumericArray::try_into_kind()`] to convert back into a `NumericArray<T>`.
488 ///
489 /// # Example
490 ///
491 /// ```no_run
492 /// # use wolfram_library_link::NumericArray;
493 /// let array: NumericArray<i64> = NumericArray::from_slice(&[1, 2, 3]);
494 ///
495 /// let array: NumericArray = array.into_generic();
496 /// ```
497 pub fn into_generic(self) -> NumericArray {
498 let NumericArray(na, PhantomData) = self;
499
500 // Don't run Drop on `self`; ownership of this value is being given to the caller.
501 std::mem::forget(self);
502
503 NumericArray(na, PhantomData)
504 }
505
506 /// Construct a `NumericArray<T>` from a raw [`MNumericArray`][sys::MNumericArray].
507 ///
508 /// # Safety
509 ///
510 /// The following conditions must be met for safe usage of this function:
511 ///
512 /// * `array` must be a fully initialized and valid numeric array object
513 /// * `T` must either:
514 /// - be `()`, representing an array with dynamic element type, or
515 /// - `T` must satisfy [`NumericArrayType`], and the element type of `array` must
516 /// be the same as `T`.
517 // TODO: Add something about the reference count in the above list?
518 pub unsafe fn from_raw(array: sys::MNumericArray) -> NumericArray<T> {
519 NumericArray(array, PhantomData)
520 }
521
522 /// Convert this `NumericArray` into a raw [`MNumericArray`][sys::MNumericArray]
523 /// object.
524 pub unsafe fn into_raw(self) -> sys::MNumericArray {
525 let NumericArray(raw, PhantomData) = self;
526
527 // Don't run Drop on `self`; ownership of this value is being given to the caller.
528 std::mem::forget(self);
529
530 raw
531 }
532
533 /// *LibraryLink C API Documentation:* [`MNumericArray_getData`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_getData.html)
534 pub fn data_ptr(&self) -> *mut c_void {
535 let NumericArray(numeric_array, _) = *self;
536
537 unsafe { data_ptr(numeric_array) }
538 }
539
540 #[allow(missing_docs)]
541 pub fn data_type(&self) -> NumericArrayDataType {
542 let value: sys::numericarray_data_t = self.data_type_raw();
543 let value: u32 = value as u32;
544
545 NumericArrayDataType::try_from(value)
546 .expect("NumericArray tensor property type is value is not a known NumericArrayDataType variant")
547 }
548
549 /// *LibraryLink C API Documentation:* [`MNumericArray_getType`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_getType.html)
550 pub fn data_type_raw(&self) -> sys::numericarray_data_t {
551 let NumericArray(numeric_array, _) = *self;
552
553 unsafe { rtl::MNumericArray_getType(numeric_array) }
554 }
555
556 /// The number of elements in the underlying flat data array.
557 ///
558 /// This is the product of the dimension lengths of this [`NumericArray`].
559 ///
560 /// This is *not* the number of bytes.
561 ///
562 /// *LibraryLink C API Documentation:* [`MNumericArray_getFlattenedLength`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_getFlattenedLength.html)
563 pub fn flattened_length(&self) -> usize {
564 let NumericArray(numeric_array, _) = *self;
565
566 let len = unsafe { flattened_length(numeric_array) };
567
568 // Check that the stored length matches the length computed from the dimensions.
569 debug_assert!(len == self.dimensions().iter().copied().product::<usize>());
570
571 len
572 }
573
574 /// *LibraryLink C API Documentation:* [`MNumericArray_getRank`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_getRank.html)
575 pub fn rank(&self) -> usize {
576 let NumericArray(numeric_array, _) = *self;
577
578 let rank: sys::mint = unsafe { rtl::MNumericArray_getRank(numeric_array) };
579
580 let rank = usize::try_from(rank).expect("NumericArray rank overflows usize");
581
582 rank
583 }
584
585 /// Get the dimensions of this `NumericArray`.
586 ///
587 /// *LibraryLink C API Documentation:* [`MNumericArray_getDimensions`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_getDimensions.html)
588 ///
589 /// # Example
590 ///
591 /// ```no_run
592 /// # use wolfram_library_link::NumericArray;
593 /// let array = NumericArray::from_array(&[2, 2], &[1, 2, 3, 4]);
594 ///
595 /// assert_eq!(array.dimensions(), &[2, 2]);
596 /// assert_eq!(array.rank(), array.dimensions().len());
597 /// ```
598 pub fn dimensions(&self) -> &[usize] {
599 let NumericArray(numeric_array, _) = *self;
600
601 let rank = self.rank();
602
603 debug_assert!(rank != 0);
604
605 let dims: *const crate::sys::mint =
606 unsafe { rtl::MNumericArray_getDimensions(numeric_array) };
607
608 const _: () = assert!(mem::size_of::<sys::mint>() == mem::size_of::<usize>());
609 let dims: *mut usize = dims as *mut usize;
610
611 debug_assert!(!dims.is_null());
612
613 unsafe { std::slice::from_raw_parts(dims, rank) }
614 }
615
616 /// Returns the share count of this `NumericArray`.
617 ///
618 /// If this `NumericArray` is not shared, the share count is 0.
619 ///
620 /// If this `NumericArray` was passed into the current library "by reference" due to
621 /// use of the `Automatic` or `"Constant"` memory management strategy, that reference
622 /// is not reflected in the `share_count()`.
623 ///
624 /// *LibraryLink C API Documentation:* [`MNumericArray_shareCount`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_shareCount.html)
625 pub fn share_count(&self) -> usize {
626 let NumericArray(raw, PhantomData) = *self;
627
628 let count: sys::mint = unsafe { rtl::MNumericArray_shareCount(raw) };
629
630 usize::try_from(count).expect("NumericArray share count mint overflows usize")
631 }
632
633 /// Returns true if `self` and `other` are pointers to the name underlying
634 /// numeric array object.
635 pub fn ptr_eq<T2>(&self, other: &NumericArray<T2>) -> bool {
636 let NumericArray(this, PhantomData) = *self;
637 let NumericArray(other, PhantomData) = *other;
638
639 this == other
640 }
641
642 /// *LibraryLink C API Documentation:* [`MNumericArray_convertType`](https://reference.wolfram.com/language/LibraryLink/ref/callback/MNumericArray_convertType.html)
643 // TODO: When can this return an error? ClipAndCheck and the tolerance is not sufficient?
644 // TODO: Return a better error than `errcode_t`.
645 pub fn convert_to<T2: NumericArrayType>(
646 &self,
647 method: NumericArrayConvertMethod,
648 tolerance: sys::mreal,
649 ) -> Result<NumericArray<T2>, sys::errcode_t> {
650 let NumericArray(self_raw, PhantomData) = *self;
651
652 let mut new_raw: sys::MNumericArray = std::ptr::null_mut();
653
654 let err_code: sys::errcode_t = unsafe {
655 rtl::MNumericArray_convertType(
656 &mut new_raw,
657 self_raw,
658 T2::TYPE.as_raw(),
659 method.as_raw(),
660 tolerance,
661 )
662 };
663
664 if err_code != 0 || new_raw.is_null() {
665 return Err(err_code);
666 }
667
668 Ok(unsafe { NumericArray::<T2>::from_raw(new_raw) })
669 }
670}
671
// Thin wrapper around `MNumericArray_getData`, shared by `NumericArray::data_ptr()`
// and `UninitNumericArray::as_slice_mut()`.
//
// Safety: `numeric_array` is passed straight to the C API, so it is presumably
// required to be a valid array handle -- confirm against the LibraryLink documentation.
unsafe fn data_ptr(numeric_array: sys::MNumericArray) -> *mut c_void {
    rtl::MNumericArray_getData(numeric_array)
}
675
676unsafe fn flattened_length(numeric_array: sys::MNumericArray) -> usize {
677 let len: sys::mint = rtl::MNumericArray_getFlattenedLength(numeric_array);
678
679 let len = usize::try_from(len).expect("i64 overflows usize");
680
681 len
682}
683
684//======================================
685// UninitNumericArray
686//======================================
687
688impl<T: NumericArrayType> UninitNumericArray<T> {
689 /// Construct a new uninitialized `NumericArray` with the specified dimensions.
690 ///
691 /// # Panics
692 ///
693 /// This function will panic if [`UninitNumericArray::try_from_dimensions()`] returns
694 /// an error.
695 pub fn from_dimensions(dimensions: &[usize]) -> UninitNumericArray<T> {
696 UninitNumericArray::try_from_dimensions(dimensions)
697 .expect("failed to create UninitNumericArray from dimensions")
698 }
699
700 /// Try to construct a new uninitialized NumericArray with the specified dimensions.
701 ///
702 /// This function will return an error if:
703 ///
704 /// * `dimensions` is empty.
705 /// * the product of `dimensions` is equal to 0.
706 /// * the underlying allocation function returns `NULL`.
707 pub fn try_from_dimensions(
708 dimensions: &[usize],
709 ) -> Result<UninitNumericArray<T>, sys::errcode_t> {
710 assert!(!dimensions.is_empty());
711
712 let rank = dimensions.len();
713 debug_assert!(rank > 0);
714
715 unsafe {
716 let mut numeric_array: sys::MNumericArray = std::ptr::null_mut();
717
718 let err_code: sys::errcode_t = rtl::MNumericArray_new(
719 <T as NumericArrayType>::TYPE.as_raw(),
720 i64::try_from(rank).expect("usize overflows i64"),
721 dimensions.as_ptr() as *mut sys::mint,
722 &mut numeric_array,
723 );
724
725 if err_code != 0 || numeric_array.is_null() {
726 return Err(err_code);
727 }
728
729 Ok(UninitNumericArray(numeric_array, PhantomData))
730 }
731 }
732
733 /// # Panics
734 ///
735 /// This function will panic if `source` does not have the same length as
736 /// this array's [`as_slice_mut()`][UninitNumericArray::as_slice_mut] slice.
737 pub fn init_from_slice(mut self, source: &[T]) -> NumericArray<T> {
738 let data = self.as_slice_mut();
739
740 // Safety: copy_from_slice_uninit() unconditionally asserts that `data` and
741 // `source` have the same number of elements, so if it succeeds we're
742 // certain that every element of the NumericArray has been initialized.
743 copy_from_slice_uninit(source, data);
744
745 unsafe { self.assume_init() }
746 }
747
748 /// Mutable access to the elements of this [`UninitNumericArray`].
749 ///
750 /// This function returns a mutable slice of [`std::mem::MaybeUninit<T>`]. This is done
751 /// because it is undefined behavior in Rust to construct a `&` (or `&mut`) reference
752 /// to a value which has not been initialized. Note that it is undefined behavior even
753 /// if the reference is never read from. The `MaybeUninit` type explicitly makes the
754 /// compiler aware that the `T` value might not be initialized.
755 ///
756 /// # Example
757 ///
758 /// Construct the numeric array `{1, 2, 3, 4, 5}`.
759 ///
760 /// ```no_run
761 /// use wolfram_library_link::{NumericArray, UninitNumericArray};
762 ///
763 /// // Construct a `1x5` numeric array with elements of type `f64`.
764 /// let mut uninit = UninitNumericArray::<f64>::from_dimensions(&[5]);
765 ///
766 /// for (index, elem) in uninit.as_slice_mut().into_iter().enumerate() {
767 /// elem.write(index as f64 + 1.0);
768 /// }
769 ///
770 /// // Now that we've taken responsibility for initializing every
771 /// // element of the UninitNumericArray, we've upheld the
772 /// // invariant necessary to make a call to `assume_init()` safe.
773 /// let array: NumericArray<f64> = unsafe { uninit.assume_init() };
774 /// ```
775 ///
776 /// See [`assume_init()`][UninitNumericArray::assume_init].
777 pub fn as_slice_mut(&mut self) -> &mut [MaybeUninit<T>] {
778 let UninitNumericArray(numeric_array, PhantomData) = *self;
779
780 unsafe {
781 let len = flattened_length(numeric_array);
782
783 let ptr: *mut c_void = data_ptr(numeric_array);
784 let ptr = ptr as *mut MaybeUninit<T>;
785
786 std::slice::from_raw_parts_mut(ptr, len)
787 }
788 }
789
790 /// Assume that this NumericArray's elements have been initialized.
791 ///
792 /// Use [`as_slice_mut()`][UninitNumericArray::as_slice_mut] to initialize the values
793 /// in this array.
794 ///
795 /// # Safety
796 ///
797 /// This function must only be called once all elements of this NumericArray have
798 /// been initialized. It is undefined behavior to construct a [`NumericArray`] without
799 /// first initializing the data array.
800 pub unsafe fn assume_init(self) -> NumericArray<T> {
801 let UninitNumericArray(expr, PhantomData) = self;
802
803 // Don't run Drop on `self`; ownership of this value is being given to the caller.
804 std::mem::forget(self);
805
806 NumericArray(expr, PhantomData)
807 }
808}
809
/// Initialize every element of `dest` with a bitwise copy of the corresponding element
/// of `src`.
///
/// This function is modeled after the `copy_from_slice()` method on the primitive
/// `slice` type. It can be used to initialize an [`UninitNumericArray`] from a slice of
/// data.
///
/// # Panics
///
/// Panics if `src` and `dest` do not have the same length.
fn copy_from_slice_uninit<T>(src: &[T], dest: &mut [MaybeUninit<T>]) {
    let count = dest.len();

    assert_eq!(
        src.len(),
        count,
        "destination and source slices have different lengths"
    );

    let dest_ptr: *mut T = dest.as_mut_ptr().cast();

    // SAFETY: both slices hold exactly `count` elements (checked above), and a shared
    // borrow cannot overlap an exclusive borrow.
    unsafe { std::ptr::copy_nonoverlapping(src.as_ptr(), dest_ptr, count) }
}
828
829impl NumericArrayDataType {
830 #[allow(missing_docs)]
831 pub fn as_raw(self) -> sys::numericarray_data_t {
832 self as sys::numericarray_data_t
833 }
834
835 /// Get the string name of this type, suitable for use in
836 /// [`NumericArray`][ref/NumericArray]<code>[<i>data</i>, "<i>type</i>"]</code>.
837 ///
838 /// [ref/NumericArray]: https://reference.wolfram.com/language/ref/NumericArray.html
839 #[rustfmt::skip]
840 pub fn name(&self) -> &'static str {
841 match self {
842 NumericArrayDataType::Bit8 => "Integer8",
843 NumericArrayDataType::Bit16 => "Integer16",
844 NumericArrayDataType::Bit32 => "Integer32",
845 NumericArrayDataType::Bit64 => "Integer64",
846
847 NumericArrayDataType::UBit8 => "UnsignedInteger8",
848 NumericArrayDataType::UBit16 => "UnsignedInteger16",
849 NumericArrayDataType::UBit32 => "UnsignedInteger32",
850 NumericArrayDataType::UBit64 => "UnsignedInteger64",
851
852 NumericArrayDataType::Real32 => "Real32",
853 NumericArrayDataType::Real64 => "Real64",
854
855 NumericArrayDataType::ComplexReal32 => "ComplexReal32",
856 NumericArrayDataType::ComplexReal64 => "ComplexReal64",
857 }
858 }
859}
860
impl NumericArrayConvertMethod {
    /// Convert this method into the raw value used by the LibraryLink C API.
    ///
    /// This is a plain cast of the enum discriminant, which is defined from the
    /// corresponding `MNumericArray_Convert_*` constant.
    #[allow(missing_docs)]
    pub fn as_raw(self) -> sys::numericarray_convert_method_t {
        self as sys::numericarray_convert_method_t
    }
}
867
868//======================================
869// Trait Impls
870//======================================
871
872impl<T> Clone for NumericArray<T> {
873 fn clone(&self) -> NumericArray<T> {
874 let NumericArray(raw, PhantomData) = *self;
875
876 unsafe {
877 let mut new: sys::MNumericArray = std::ptr::null_mut();
878 let err_code: sys::errcode_t = rtl::MNumericArray_clone(raw, &mut new);
879
880 if err_code != 0 || new.is_null() {
881 panic!("NumericArray clone failed with error code: {}", err_code);
882 }
883
884 NumericArray::<T>::from_raw(new)
885 }
886 }
887}
888
889impl<T> Drop for NumericArray<T> {
890 fn drop(&mut self) {
891 if self.share_count() > 0 {
892 // This is a "Shared" numeric array, so we should decrement the reference
893 // count.
894 let NumericArray(raw, PhantomData) = *self;
895 unsafe { rtl::MNumericArray_disown(raw) }
896 } else {
897 // This is a "Manual" numeric array (or one created within Rust), so we should
898 // free its memory directly.
899 let NumericArray(raw, PhantomData) = *self;
900 unsafe { rtl::MNumericArray_free(raw) }
901 }
902 }
903}
904
905impl<T> fmt::Debug for NumericArray<T> {
906 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
907 f.debug_struct("NumericArray")
908 .field("raw", &self.0)
909 .field("data_type", &self.data_type())
910 .finish()
911 }
912}
913
914//======================================
915// Conversion Impls
916//======================================
917
918impl TryFrom<u32> for NumericArrayDataType {
919 type Error = ();
920
921 fn try_from(value: u32) -> Result<Self, Self::Error> {
922 // debug_assert!(u32::try_from(self.tensor_property_type()).is_ok());
923
924 #[rustfmt::skip]
925 let ok = match value {
926 _ if value == BIT8_TYPE as u32 => NumericArrayDataType::Bit8,
927 _ if value == BIT16_TYPE as u32 => NumericArrayDataType::Bit16,
928 _ if value == BIT32_TYPE as u32 => NumericArrayDataType::Bit32,
929 _ if value == BIT64_TYPE as u32 => NumericArrayDataType::Bit64,
930
931 _ if value == UBIT8_TYPE as u32 => NumericArrayDataType::UBit8,
932 _ if value == UBIT16_TYPE as u32 => NumericArrayDataType::UBit16,
933 _ if value == UBIT32_TYPE as u32 => NumericArrayDataType::UBit32,
934 _ if value == UBIT64_TYPE as u32 => NumericArrayDataType::UBit64,
935
936 _ if value == REAL32_TYPE as u32 => NumericArrayDataType::Real32,
937 _ if value == REAL64_TYPE as u32 => NumericArrayDataType::Real64,
938
939 _ if value == COMPLEX_REAL32_TYPE as u32 => NumericArrayDataType::ComplexReal32,
940 _ if value == COMPLEX_REAL64_TYPE as u32 => NumericArrayDataType::ComplexReal64,
941
942 _ => return Err(()),
943 };
944
945 Ok(ok)
946 }
947}