1use std::usize;
2
3use ndarray::{prelude::*, RemoveAxis};
4use ndarray::{Axis, Data};
5use num_traits::Float;
6use thiserror::Error;
7
8#[derive(Error, Debug)]
9pub enum LogicalError {
10 #[error("axis {axis:?} is greater than array dimension {bound:?}")]
11 AxisOutOfBoundsError { axis: usize, bound: usize },
12}
13
14fn check_axis(axis: Axis, shape: &[usize]) -> Result<(), LogicalError> {
16 for dim in shape.iter() {
17 let bound = *dim - 1;
18
19 if axis.0 > bound {
20 return Err(LogicalError::AxisOutOfBoundsError {
21 axis: axis.0,
22 bound,
23 });
24 }
25 }
26
27 Ok(())
28}
29
30pub fn all<S, D, A>(arr: &ArrayBase<S, D>) -> bool
35where
36 A: Float,
37 S: Data<Elem = A>,
38 D: Dimension,
39{
40 let result: bool = arr.iter().all(|ele| ele.is_normal());
41
42 result
43}
44
45pub fn axis_all<S, D, A>(arr: &ArrayBase<S, D>, axis: Axis) -> Result<Array1<bool>, LogicalError>
62where
63 A: Float,
64 S: Data<Elem = A>,
65 D: Dimension + RemoveAxis,
66{
67 check_axis(axis, arr.shape())?;
68
69 let result: Array1<bool> = arr
70 .axis_iter(axis)
71 .map(|axis_view| self::all(&axis_view))
72 .collect();
73
74 Ok(result)
75}
76
77pub fn any<S, D, A>(arr: &ArrayBase<S, D>) -> bool
82where
83 A: Float,
84 S: Data<Elem = A>,
85 D: Dimension,
86{
87 let result: bool = arr.iter().any(|ele| ele.is_normal());
88
89 result
90}
91
92pub fn axis_any<S, D, A>(arr: &ArrayBase<S, D>, axis: Axis) -> Result<Array1<bool>, LogicalError>
109where
110 A: Float,
111 S: Data<Elem = A>,
112 D: Dimension + RemoveAxis,
113{
114 check_axis(axis, arr.shape())?;
115
116 let result: Array1<bool> = arr
117 .axis_iter(axis)
118 .map(|axis_view| self::any(&axis_view))
119 .collect();
120
121 Ok(result)
122}
123
124pub fn is_nan<S, D, A>(arr: &ArrayBase<S, D>) -> bool
127where
128 A: Float,
129 S: Data<Elem = A>,
130 D: Dimension,
131{
132 let result: bool = arr.iter().any(|ele| ele.is_nan());
133
134 result
135}
136
137pub fn axis_is_nan<S, D, A>(arr: &ArrayBase<S, D>, axis: Axis) -> Result<Array1<bool>, LogicalError>
154where
155 A: Float,
156 S: Data<Elem = A>,
157 D: Dimension + RemoveAxis,
158{
159 check_axis(axis, arr.shape())?;
160
161 let result: Array1<bool> = arr
162 .axis_iter(axis)
163 .map(|axis_view| self::is_nan(&axis_view))
164 .collect();
165
166 Ok(result)
167}
168
169pub fn is_inf<S, D, A>(arr: &ArrayBase<S, D>) -> bool
172where
173 A: Float,
174 S: Data<Elem = A>,
175 D: Dimension,
176{
177 let result: bool = arr.iter().any(|ele| ele.is_infinite());
178
179 result
180}
181
182pub fn axis_is_inf<S, D, A>(arr: &ArrayBase<S, D>, axis: Axis) -> Result<Array1<bool>, LogicalError>
199where
200 A: Float,
201 S: Data<Elem = A>,
202 D: Dimension + RemoveAxis,
203{
204 check_axis(axis, arr.shape())?;
205
206 let result: Array1<bool> = arr
207 .axis_iter(axis)
208 .map(|axis_view| self::is_inf(&axis_view))
209 .collect();
210
211 Ok(result)
212}
213
#[cfg(test)]
mod tests {
    use super::{
        all, any, axis_all, axis_any, axis_is_inf, axis_is_nan, is_inf, is_nan, LogicalError,
    };
    use ndarray::{array, Array1, Axis};
    use num_traits::Float;

    #[test]
    fn test_is_nan() {
        let arr = array![1., 2., f64::NAN];

        assert!(is_nan(&arr));
        assert!(!is_nan(&array![1., 2., 3.]));
    }

    #[test]
    fn test_is_inf() {
        let arr = array![1f32, 2f32, Float::infinity()];

        assert!(is_inf(&arr));
        // Negative infinity is infinite as well.
        assert!(is_inf(&array![f64::NEG_INFINITY]));
        // NaN is not infinite.
        assert!(!is_inf(&array![1., f64::NAN]));
    }

    #[test]
    fn test_all() {
        assert!(all(&array![1., 2., 3.]));
        // Zero and NaN are not normal floats, so either one makes `all` false.
        assert!(!all(&array![1., 0., 3.]));
        assert!(!all(&array![1., f64::NAN]));
    }

    #[test]
    fn test_any() {
        let arr = array![1., 2., Float::infinity()];

        assert!(any(&arr));
        // Only non-normal values: nothing counts.
        assert!(!any(&array![0., f64::INFINITY, f64::NAN]));
    }

    #[test]
    fn test_empty_arrays() {
        let empty = Array1::<f64>::zeros(0);

        assert!(all(&empty));
        assert!(!any(&empty));
        assert!(!is_nan(&empty));
        assert!(!is_inf(&empty));
    }

    #[test]
    fn test_axis_all() {
        let arr = array![[f32::INFINITY, 42.], [2., 11.]];
        assert_eq!(axis_all(&arr, Axis(0)).unwrap(), array![false, true]);
    }

    #[test]
    fn test_axis_any() {
        let arr = array![[f32::INFINITY, f32::INFINITY], [f32::NAN, 11.]];
        let result = axis_any(&arr, Axis(0)).expect("axis 0 is valid for a 2-D array");

        assert_eq!(result, array![false, true]);
    }

    #[test]
    fn test_axis_is_nan() {
        let arr = array![[f64::NAN, 0.], [2., 11.]];
        assert_eq!(axis_is_nan(&arr, Axis(0)).unwrap(), array![true, false]);
    }

    #[test]
    fn test_axis_is_inf() {
        let arr = array![[f64::INFINITY, 0.], [2., 11.]];
        assert_eq!(axis_is_inf(&arr, Axis(0)).unwrap(), array![true, false]);
    }

    #[test]
    fn test_axis_out_of_bounds() {
        let arr = array![[1., 2.], [3., 4.]];

        assert!(matches!(
            axis_all(&arr, Axis(5)),
            Err(LogicalError::AxisOutOfBoundsError { axis: 5, .. })
        ));
        assert!(axis_any(&arr, Axis(5)).is_err());
        assert!(axis_is_nan(&arr, Axis(5)).is_err());
        assert!(axis_is_inf(&arr, Axis(5)).is_err());
    }
}
270}