rust_optimal_transport/
ndarray_logical.rs

1use std::usize;
2
3use ndarray::{prelude::*, RemoveAxis};
4use ndarray::{Axis, Data};
5use num_traits::Float;
6use thiserror::Error;
7
8#[derive(Error, Debug)]
9pub enum LogicalError {
10    #[error("axis {axis:?} is greater than array dimension {bound:?}")]
11    AxisOutOfBoundsError { axis: usize, bound: usize },
12}
13
14/// Checks if a given ndarray Axis is valid for a set of dimensions
15fn check_axis(axis: Axis, shape: &[usize]) -> Result<(), LogicalError> {
16    for dim in shape.iter() {
17        let bound = *dim - 1;
18
19        if axis.0 > bound {
20            return Err(LogicalError::AxisOutOfBoundsError {
21                axis: axis.0,
22                bound,
23            });
24        }
25    }
26
27    Ok(())
28}
29
30/// Returns True if all array elements are neither zero,
31/// infinite, subnormal, or NaN. Subnormal values are those
32/// between '0' and f32 or f64::MIN_POSITIVE
33/// Returns False otherwise
34pub fn all<S, D, A>(arr: &ArrayBase<S, D>) -> bool
35where
36    A: Float,
37    S: Data<Elem = A>,
38    D: Dimension,
39{
40    let result: bool = arr.iter().all(|ele| ele.is_normal());
41
42    result
43}
44
45/// Tests whether all array elements along a given axis evaluates to True
46///
47/// Returns an array of booleans
48///
49/// Example:
50///
51/// ```rust
52///
53/// use rust_optimal_transport as ot;
54/// use ot::ndarray_logical::axis_all;
55/// use ndarray::{prelude::*, Axis};
56///
57/// let arr = array![[f32::INFINITY, 42.], [2., 11.]];
58/// assert_eq!(axis_all(&arr, Axis(0)).unwrap(), array![false, true]);
59/// ```
60///
61pub fn axis_all<S, D, A>(arr: &ArrayBase<S, D>, axis: Axis) -> Result<Array1<bool>, LogicalError>
62where
63    A: Float,
64    S: Data<Elem = A>,
65    D: Dimension + RemoveAxis,
66{
67    check_axis(axis, arr.shape())?;
68
69    let result: Array1<bool> = arr
70        .axis_iter(axis)
71        .map(|axis_view| self::all(&axis_view))
72        .collect();
73
74    Ok(result)
75}
76
77/// Tests whether any array element evaluates to True
78/// Returns true if the number is neither zero, infinite, subnormal, or NaN.
79/// Subnormal values are those between '0' and 'f32 or f64::MIN_POSITIVE'
80/// Returns false for empty arrays
81pub fn any<S, D, A>(arr: &ArrayBase<S, D>) -> bool
82where
83    A: Float,
84    S: Data<Elem = A>,
85    D: Dimension,
86{
87    let result: bool = arr.iter().any(|ele| ele.is_normal());
88
89    result
90}
91
92/// Tests whether any array element along a given axis evaluates to True
93///
94/// Returns an array of booleans
95///
96/// Example:
97///
98/// ```rust
99///
100/// use rust_optimal_transport as ot;
101/// use ot::ndarray_logical::axis_any;
102/// use ndarray::{prelude::*, Axis};
103///
104/// let arr = array![[f32::INFINITY, f32::INFINITY], [f32::NAN, 11.]];
105/// assert_eq!(axis_any(&arr, Axis(0)).unwrap(), array![false, true]);
106/// ```
107///
108pub fn axis_any<S, D, A>(arr: &ArrayBase<S, D>, axis: Axis) -> Result<Array1<bool>, LogicalError>
109where
110    A: Float,
111    S: Data<Elem = A>,
112    D: Dimension + RemoveAxis,
113{
114    check_axis(axis, arr.shape())?;
115
116    let result: Array1<bool> = arr
117        .axis_iter(axis)
118        .map(|axis_view| self::any(&axis_view))
119        .collect();
120
121    Ok(result)
122}
123
124/// Tests element-wise for NaN elements in an array.
125/// Returns True if there are NaN, False otherwise
126pub fn is_nan<S, D, A>(arr: &ArrayBase<S, D>) -> bool
127where
128    A: Float,
129    S: Data<Elem = A>,
130    D: Dimension,
131{
132    let result: bool = arr.iter().any(|ele| ele.is_nan());
133
134    result
135}
136
137/// Tests whether any array element along a given axis is NaN
138///
139/// Returns an array of booleans
140///
141/// Example:
142///
143/// ```rust
144///
145/// use rust_optimal_transport as ot;
146/// use ot::ndarray_logical::axis_is_nan;
147/// use ndarray::{prelude::*, Axis};
148///
149/// let arr = array![[f64::NAN, 0.], [2., 11.]];
150/// assert_eq!(axis_is_nan(&arr, Axis(0)).unwrap(), array![true, false]);
151/// ```
152///
153pub fn axis_is_nan<S, D, A>(arr: &ArrayBase<S, D>, axis: Axis) -> Result<Array1<bool>, LogicalError>
154where
155    A: Float,
156    S: Data<Elem = A>,
157    D: Dimension + RemoveAxis,
158{
159    check_axis(axis, arr.shape())?;
160
161    let result: Array1<bool> = arr
162        .axis_iter(axis)
163        .map(|axis_view| self::is_nan(&axis_view))
164        .collect();
165
166    Ok(result)
167}
168
169/// Tests element-wise for inf elements in an array.
170/// Returns True if there are NaN, False otherwise
171pub fn is_inf<S, D, A>(arr: &ArrayBase<S, D>) -> bool
172where
173    A: Float,
174    S: Data<Elem = A>,
175    D: Dimension,
176{
177    let result: bool = arr.iter().any(|ele| ele.is_infinite());
178
179    result
180}
181
182/// Tests whether any array element along a given axis is inf
183///
184/// Returns an array of booleans
185///
186/// Example:
187///
188/// ```rust
189///
190/// use rust_optimal_transport as ot;
191/// use ot::ndarray_logical::axis_is_inf;
192/// use ndarray::{prelude::*, Axis};
193///
194/// let arr = array![[f64::INFINITY, 0.], [2., 11.]];
195/// assert_eq!(axis_is_inf(&arr, Axis(0)).unwrap(), array![true, false]);
196/// ```
197///
198pub fn axis_is_inf<S, D, A>(arr: &ArrayBase<S, D>, axis: Axis) -> Result<Array1<bool>, LogicalError>
199where
200    A: Float,
201    S: Data<Elem = A>,
202    D: Dimension + RemoveAxis,
203{
204    check_axis(axis, arr.shape())?;
205
206    let result: Array1<bool> = arr
207        .axis_iter(axis)
208        .map(|axis_view| self::is_inf(&axis_view))
209        .collect();
210
211    Ok(result)
212}
213
#[cfg(test)]
mod tests {
    use super::{any, axis_all, axis_any, axis_is_inf, axis_is_nan, is_inf, is_nan};
    use ndarray::{array, Axis};
    use num_traits::Float;

    #[test]
    fn test_is_nan() {
        assert!(is_nan(&array![1., 2., f64::NAN]));
    }

    #[test]
    fn test_is_inf() {
        let values = array![1f32, 2f32, Float::infinity()];
        assert!(is_inf(&values));
    }

    #[test]
    fn test_any() {
        let values = array![1., 2., Float::infinity()];
        assert!(any(&values));
    }

    #[test]
    fn test_axis_all() {
        let matrix = array![[f32::INFINITY, 42.], [2., 11.]];
        let expected = array![false, true];
        assert_eq!(axis_all(&matrix, Axis(0)).unwrap(), expected);
    }

    #[test]
    fn test_axis_any() {
        let matrix = array![[f32::INFINITY, f32::INFINITY], [f32::NAN, 11.]];
        let flags = axis_any(&matrix, Axis(0)).unwrap_or_else(|error| panic!("{:?}", error));

        assert_eq!(flags, array![false, true]);
    }

    #[test]
    fn test_axis_is_nan() {
        let matrix = array![[f64::NAN, 0.], [2., 11.]];
        let expected = array![true, false];
        assert_eq!(axis_is_nan(&matrix, Axis(0)).unwrap(), expected);
    }

    #[test]
    fn test_axis_is_inf() {
        let matrix = array![[f64::INFINITY, 0.], [2., 11.]];
        let expected = array![true, false];
        assert_eq!(axis_is_inf(&matrix, Axis(0)).unwrap(), expected);
    }
}