1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
/*!
 * This module wraps the raw ffi bindings (_See_ [crate::mex::raw]) exposed by each API target
 * to something a bit more ergonomic. For example, when the MEX api exposes a pointer and
 * a length, these are combined into a slice.
 */

use std::ptr::NonNull;
use num_complex::Complex;
use cfg_if::cfg_if;

use raw::{
	mxGetDimensions,
	mxGetNumberOfDimensions,
	mxGetNumberOfElements,
	mxIsComplex,
	mxIsSparse,
	mxGetClassID,
	mxGetData,

	mxClassID,

	mxClassID_mxUNKNOWN_CLASS as UNKNOWN_CLASS,
	mxClassID_mxCELL_CLASS as CELL_CLASS,
	mxClassID_mxSTRUCT_CLASS as STRUCT_CLASS,
	mxClassID_mxCHAR_CLASS as CHAR_CLASS,
	mxClassID_mxVOID_CLASS as VOID_CLASS,
	mxClassID_mxFUNCTION_CLASS as FUNCTION_CLASS,

	mxClassID_mxDOUBLE_CLASS as DOUBLE_CLASS,
	mxClassID_mxSINGLE_CLASS as SINGLE_CLASS,

	mxClassID_mxLOGICAL_CLASS as LOGICAL_CLASS,
	mxClassID_mxINT8_CLASS as I8_CLASS,
	mxClassID_mxUINT8_CLASS as U8_CLASS,
	mxClassID_mxINT16_CLASS as I16_CLASS,
	mxClassID_mxUINT16_CLASS as U16_CLASS,
	mxClassID_mxINT32_CLASS as I32_CLASS,
	mxClassID_mxUINT32_CLASS as U32_CLASS,
	mxClassID_mxINT64_CLASS as I64_CLASS,
	mxClassID_mxUINT64_CLASS as U64_CLASS,
};

#[cfg(any(not(feature = "octave"), feature = "doc"))]
use raw::{
	mxClassID_mxOBJECT_CLASS as OBJECT_CLASS,
	mxClassID_mxOPAQUE_CLASS as OPAQUE_CLASS,
};

use raw::mxArray;

pub mod raw;
pub mod alloc;
pub mod pointers;

/**
 * Possible errors which can occur when converting an mxArray into a Rust type.
 *
 * Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be boxed into
 * a `Box<dyn Error>` or wrapped by other error types.
 */
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum FromMatlabError {
	/// Returned when the type of the mxArray does not agree
	BadClass,
	/// Returned when the complexity of the mxArray does not agree with the type
	BadComplexity,
	/// Returned when the sparsity does not match the expected sparsity
	BadSparsity,
	/// Returned when the size of the mxArray does not match
	Size,
}

impl std::fmt::Display for FromMatlabError {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		// No wildcard arm: adding a variant should force a message to be written for it.
		let message = match self {
			Self::BadClass => "the class of the mxArray does not agree with the requested type",
			Self::BadComplexity => "the complexity of the mxArray does not agree with the requested type",
			Self::BadSparsity => "the sparsity of the mxArray does not match the expected sparsity",
			Self::Size => "the size of the mxArray does not match",
		};
		f.write_str(message)
	}
}

impl std::error::Error for FromMatlabError {}

/**
 * Possible errors which can occur when converting a Rust value into an mxArray.
 *
 * Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be boxed into
 * a `Box<dyn Error>` or wrapped by other error types.
 */
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum ToMatlabError {
	/// A size did not match.
	// NOTE(review): presumably the data length vs. the requested array shape — confirm
	// at the sites which construct this variant.
	MismatchedSize,
	/// A size mismatch involving complex data.
	// NOTE(review): presumably the real and imaginary parts differ in length — confirm
	// at the sites which construct this variant.
	ComplexSizeMismatch,
}

impl std::fmt::Display for ToMatlabError {
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		// No wildcard arm: adding a variant should force a message to be written for it.
		let message = match self {
			Self::MismatchedSize => "mismatched size",
			Self::ComplexSizeMismatch => "complex size mismatch",
		};
		f.write_str(message)
	}
}

impl std::error::Error for ToMatlabError {}

/// Associates a Rust type with the class ID and complexity an mxArray must report in
/// order to hold values of that type.
pub trait MatlabClass {
	/// Class ID an mxArray must report to match this type.
	const CLASS_ID: ClassID;
	/// Complexity an mxArray must report to match this type.
	const COMPLEXITY: bool;

	/// Checks `mx` against this type's class and complexity, additionally requiring
	/// that the array is not sparse.
	fn correct_class(mx: &mxArray) -> Result<(), FromMatlabError> {
		let (complexity, sparsity) = (Self::COMPLEXITY, false);
		Self::correct_class_complexity(mx, complexity, sparsity)
	}

	/// Checks `mx` against this type's class and the given complexity and sparsity.
	///
	/// The checks run in the order class, complexity, sparsity; the first one to fail
	/// determines the returned error.
	fn correct_class_complexity(mx: &mxArray, complexity: bool, sparsity: bool)
		-> Result<(), FromMatlabError>
	{
		if mx.class_id() != Self::CLASS_ID {
			Err(FromMatlabError::BadClass)
		} else if mx.is_complex() != complexity {
			Err(FromMatlabError::BadComplexity)
		} else if mx.is_sparse() != sparsity {
			Err(FromMatlabError::BadSparsity)
		} else {
			Ok(())
		}
	}
}

/**
 * Marker trait for Matlab types which allow for cheap conversions to Rust types. For
 * example, an array of `u32`s is (mostly) the same in Matlab as it is in Rust, while a
 * struct in Matlab is markedly different from a struct in Rust. The former can be
 * converted between them with some pointer juggling, the latter cannot.
 *
 * Note that when MATLAB's interleaved complex API is used, complex values are
 * considered a primitive; when the old API is used, they are not (conversion from a
 * complex mxArray then requires copying to fit into the `Complex<T>` type).
 *
 * Implementors must also implement [`MatlabClass`] and be `Copy`.
 */
pub trait MatlabPrimitive: MatlabClass + Copy {}

/**
 * Marker trait for complex counterparts of [`MatlabPrimitive`] types. Only implemented
 * when `Complex<T>` isn't itself a `MatlabPrimitive` (this depends on the target API).
 *
 * Implementors must also implement [`MatlabClass`].
 */
pub trait MatlabComplex: MatlabClass {}

/// Implements `MatlabClass` for the type `$ty`, using `$complex` as its `COMPLEXITY`
/// and `$class` as its `CLASS_ID`.
macro_rules! impl_mlc_for {
	($ty:ty, $complex:expr, $class:expr) => {
		impl MatlabClass for $ty {
			const COMPLEXITY: bool = $complex;
			const CLASS_ID: ClassID = $class;
		}
	};
}

/// Marks `$ty` as a `MatlabPrimitive`, and implements `MatlabClass` for it through
/// `impl_mlc_for!` with the same arguments.
macro_rules! impl_mlp_for {
	($ty:ty, $complex:expr, $class:expr) => {
		impl MatlabPrimitive for $ty {}
		impl_mlc_for!($ty, $complex, $class);
	};
}

/// Marks `$ty` as a `MatlabComplex`, and implements `MatlabClass` for it through
/// `impl_mlc_for!` with the same arguments.
// Only invoked from the non-`matlab_interleaved` branch of the `cfg_if!` in this file,
// so the macro goes unused when that feature is enabled.
#[allow(unused)]
macro_rules! impl_mlk_for {
	($ty:ty, $complex:expr, $class:expr) => {
		impl MatlabComplex for $ty {}
		impl_mlc_for!($ty, $complex, $class);
	};
}

// Real-valued element types. Each is registered as a `MatlabPrimitive` with
// `COMPLEXITY = false` and the class ID given in the last argument.
impl_mlp_for!(f64, false, ClassID::Double);
impl_mlp_for!(f32, false, ClassID::Single);
impl_mlp_for!(bool,false, ClassID::Logical);
impl_mlp_for!(i8,  false, ClassID::I8);
impl_mlp_for!(u8,  false, ClassID::U8);
impl_mlp_for!(i16, false, ClassID::I16);
impl_mlp_for!(u16, false, ClassID::U16);
impl_mlp_for!(i32, false, ClassID::I32);
impl_mlp_for!(u32, false, ClassID::U32);
impl_mlp_for!(i64, false, ClassID::I64);
impl_mlp_for!(u64, false, ClassID::U64);

// Complex element types. Each `Complex<T>` shares the class ID of its `T` and has
// `COMPLEXITY = true`. Which marker trait it receives depends on the
// `matlab_interleaved` feature.
// NOTE(review): `Complex<bool>` is registered for `ClassID::Logical`; it is unclear
// whether a complex logical array can actually occur — confirm this impl is wanted.
cfg_if! {
	if #[cfg(feature="matlab_interleaved")] {
		// With `matlab_interleaved`, complex values are `MatlabPrimitive`s.
		impl_mlp_for!(Complex<f64>, true, ClassID::Double);
		impl_mlp_for!(Complex<f32>, true, ClassID::Single);
		impl_mlp_for!(Complex<bool>,true, ClassID::Logical);
		impl_mlp_for!(Complex<i8>,  true, ClassID::I8);
		impl_mlp_for!(Complex<u8>,  true, ClassID::U8);
		impl_mlp_for!(Complex<i16>, true, ClassID::I16);
		impl_mlp_for!(Complex<u16>, true, ClassID::U16);
		impl_mlp_for!(Complex<i32>, true, ClassID::I32);
		impl_mlp_for!(Complex<u32>, true, ClassID::U32);
		impl_mlp_for!(Complex<i64>, true, ClassID::I64);
		impl_mlp_for!(Complex<u64>, true, ClassID::U64);
	} else {
		// Without `matlab_interleaved`, complex values only get the `MatlabComplex`
		// marker and are not `MatlabPrimitive`s.
		impl_mlk_for!(Complex<f64>, true, ClassID::Double);
		impl_mlk_for!(Complex<f32>, true, ClassID::Single);
		impl_mlk_for!(Complex<bool>,true, ClassID::Logical);
		impl_mlk_for!(Complex<i8>,  true, ClassID::I8);
		impl_mlk_for!(Complex<u8>,  true, ClassID::U8);
		impl_mlk_for!(Complex<i16>, true, ClassID::I16);
		impl_mlk_for!(Complex<u16>, true, ClassID::U16);
		impl_mlk_for!(Complex<i32>, true, ClassID::I32);
		impl_mlk_for!(Complex<u32>, true, ClassID::U32);
		impl_mlk_for!(Complex<i64>, true, ClassID::I64);
		impl_mlk_for!(Complex<u64>, true, ClassID::U64);
	}
}

/// Rust-side mirror of the raw `mxClassID` values.
///
/// Every variant's discriminant is the corresponding `mxClassID_mx*_CLASS` constant
/// imported from the `raw` bindings, so a `ClassID` can be cast directly to the raw
/// value.
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
#[repr(u32)]
pub enum ClassID {
	/// Mirrors `mxUNKNOWN_CLASS`.
	Unknown = UNKNOWN_CLASS,
	/// Mirrors `mxCELL_CLASS`.
	Cell = CELL_CLASS,
	/// Mirrors `mxSTRUCT_CLASS`.
	Struct = STRUCT_CLASS,
	/// Mirrors `mxCHAR_CLASS`.
	Char = CHAR_CLASS,
	/// Mirrors `mxVOID_CLASS`.
	Void = VOID_CLASS,

	/// Mirrors `mxLOGICAL_CLASS`; the class of `bool` data.
	Logical = LOGICAL_CLASS,

	/// Mirrors `mxDOUBLE_CLASS`; the class of `f64` data.
	Double = DOUBLE_CLASS,
	/// Mirrors `mxSINGLE_CLASS`; the class of `f32` data.
	Single = SINGLE_CLASS,

	/// Mirrors `mxINT8_CLASS`; the class of `i8` data.
	I8 = I8_CLASS,
	/// Mirrors `mxUINT8_CLASS`; the class of `u8` data.
	U8 = U8_CLASS,

	/// Mirrors `mxINT16_CLASS`; the class of `i16` data.
	I16 = I16_CLASS,
	/// Mirrors `mxUINT16_CLASS`; the class of `u16` data.
	U16 = U16_CLASS,

	/// Mirrors `mxINT32_CLASS`; the class of `i32` data.
	I32 = I32_CLASS,
	/// Mirrors `mxUINT32_CLASS`; the class of `u32` data.
	U32 = U32_CLASS,

	/// Mirrors `mxINT64_CLASS`; the class of `i64` data.
	I64 = I64_CLASS,
	/// Mirrors `mxUINT64_CLASS`; the class of `u64` data.
	U64 = U64_CLASS,

	/// Mirrors `mxFUNCTION_CLASS`.
	Function = FUNCTION_CLASS,

	/// Mirrors `mxOPAQUE_CLASS`. Not available when the `octave` feature is enabled
	/// (unless the `doc` feature is enabled as well).
	#[cfg(any(not(feature = "octave"), feature = "doc"))]
	Opaque = OPAQUE_CLASS,
	/// Mirrors `mxOBJECT_CLASS`. Not available when the `octave` feature is enabled
	/// (unless the `doc` feature is enabled as well).
	#[cfg(any(not(feature = "octave"), feature = "doc"))]
	Object = OBJECT_CLASS,
}

impl From<mxClassID> for ClassID {
	fn from(cid: mxClassID) -> Self {
		match cid {
			UNKNOWN_CLASS => Self::Unknown,
			CELL_CLASS => Self::Cell,
			STRUCT_CLASS => Self::Struct,
			LOGICAL_CLASS => Self::Logical,
			CHAR_CLASS => Self::Char,
			VOID_CLASS => Self::Void,

			DOUBLE_CLASS => Self::Double,
			SINGLE_CLASS => Self::Single,

			I8_CLASS => Self::I8,
			U8_CLASS => Self::U8,
			U16_CLASS => Self::U16,
			I16_CLASS => Self::I16,
			U32_CLASS => Self::U32,
			I32_CLASS => Self::I32,
			U64_CLASS => Self::U64,
			I64_CLASS => Self::I64,

			FUNCTION_CLASS => Self::Function,

			#[cfg(any(not(feature = "octave"), feature = "doc"))]
			OBJECT_CLASS => Self::Object,
			#[cfg(any(not(feature = "octave"), feature = "doc"))]
			OPAQUE_CLASS => Self::Opaque,

			_ => panic!("Unrecognised class value")
		}
	}
}

/// Converts a [`ClassID`] back into the raw `mxClassID` value it mirrors.
impl From<ClassID> for mxClassID {
	fn from(cid: ClassID) -> Self {
		// Every `ClassID` discriminant is itself a valid raw class ID, which makes a
		// plain cast to the underlying value sufficient.
		let raw_value = cid as mxClassID;
		raw_value
	}
}

/**
 * Macro to construct a valid pointer for slice construction
 *
 * Matlab returns, for empty arrays, a null pointer. For this case, we want to construct
 * an empty slice, but cannot pass in the null pointer for that. Instead, per
 * [std::slice::from_raw_parts]' documentation, we can obtain a valid pointer through NonNull.
 *
 * `$p` is the expression yielding the (possibly null) raw pointer; `$t` is the raw
 * pointer type to produce, e.g. `*const f64`. A non-null `$p` is simply cast to `$t`.
 *
 * The dangling pointer is created for the pointee type of `$t`, not for the pointee
 * type of `$p`: `from_raw_parts` requires the pointer to be aligned for the element
 * type even for zero-length slices, and a dangling pointer for e.g. `c_void` is not
 * sufficiently aligned for an `f64`.
 */
macro_rules! data_or_dangling {
	($p:expr, $t:ty) => {{
		let ptr = { $p };
		if ptr.is_null() {
			// The type annotation makes inference pick the pointee of `$t` for
			// `NonNull::dangling`, which yields a pointer aligned for that type.
			let dangling: $t = ::std::ptr::NonNull::dangling().as_ptr();
			dangling
		} else {
			ptr as $t
		}
	}}
}

impl mxArray {
	/// Return the sizes of the constituent dimensions of the mxArray
	pub fn dimensions(&self) -> &[usize] {
		unsafe {
			std::slice::from_raw_parts(
				mxGetDimensions(self) as *const usize,
				mxGetNumberOfDimensions(self) as usize
			)
		}
	}

	/// Return the number of elements contained in this array.
	pub fn numel(&self) -> usize {
		unsafe { mxGetNumberOfElements(self) }
	}

	pub fn class_id(&self) -> ClassID {
		unsafe { mxGetClassID(self) }.into()
	}

	pub fn is_complex(&self) -> bool {
		unsafe { mxIsComplex(self) }
	}

	pub fn is_sparse(&self) -> bool {
		unsafe { mxIsSparse(self) }
	}

	/// Return the underlying data slice of the mxArray object.
	pub fn data_slice<T>(&self) -> Result<&[T], FromMatlabError> where
		T: MatlabPrimitive
	{
		T::correct_class(self)?;

		let ptr = data_or_dangling!(unsafe { mxGetData(self) }, *const T );

		let numel = self.numel();

		Ok(unsafe { std::slice::from_raw_parts(ptr, numel) } )
	}

	#[cfg(any(feature = "matlab_separated", feature = "octave"))]
	pub fn data_slices<T>(&self) -> Result<Complex<&[T]>, FromMatlabError> where
		T: MatlabPrimitive,
		Complex<T>: MatlabComplex
	{
		use raw::mxGetImagData;

		T::correct_class_complexity(self, true, false)?;

		let re = data_or_dangling!( unsafe { mxGetData(self) }, *const T);
		let im = data_or_dangling!( unsafe { mxGetImagData(self) }, *const T);

		let numel = self.numel();

		Ok(Complex {
			re: unsafe { std::slice::from_raw_parts(re, numel) },
			im: unsafe { std::slice::from_raw_parts(im, numel) }
		})
	}
}