// tenflowers_core/tensor/comparison.rs
//! Comparison and Logical Operations
//!
//! This module provides tensor comparison operations (==, !=, >, <, etc.)
//! and logical operations for boolean tensors. All operations support
//! broadcasting and both CPU and GPU execution.

use super::core::{Tensor, TensorStorage};
use crate::Result;
10impl<T> Tensor<T>
11where
12    T: Clone
13        + Default
14        + scirs2_core::num_traits::Zero
15        + scirs2_core::num_traits::One
16        + Send
17        + Sync
18        + 'static
19        + bytemuck::Pod
20        + bytemuck::Zeroable,
21{
22    /// Element-wise equality comparison
23    pub fn eq(&self, other: &Self) -> Result<Tensor<bool>>
24    where
25        T: PartialEq,
26    {
27        if self.device() != other.device() {
28            return Err(crate::TensorError::device_mismatch(
29                "comparison",
30                &self.device().to_string(),
31                &other.device().to_string(),
32            ));
33        }
34
35        let broadcast_shape = self.shape().broadcast_shape(other.shape()).ok_or_else(|| {
36            crate::TensorError::ShapeMismatch {
37                operation: "broadcast".to_string(),
38                expected: self.shape().to_string(),
39                got: other.shape().to_string(),
40                context: None,
41            }
42        })?;
43
44        match (&self.storage, &other.storage) {
45            (TensorStorage::Cpu(arr_a), TensorStorage::Cpu(arr_b)) => {
46                use scirs2_core::ndarray::{ArrayD, IxDyn, Zip};
47
48                let a_broadcast =
49                    arr_a
50                        .broadcast(IxDyn(broadcast_shape.dims()))
51                        .ok_or_else(|| {
52                            crate::TensorError::invalid_argument(
53                                "Cannot broadcast first tensor".to_string(),
54                            )
55                        })?;
56                let b_broadcast =
57                    arr_b
58                        .broadcast(IxDyn(broadcast_shape.dims()))
59                        .ok_or_else(|| {
60                            crate::TensorError::invalid_argument(
61                                "Cannot broadcast second tensor".to_string(),
62                            )
63                        })?;
64
65                let mut result = ArrayD::default(a_broadcast.raw_dim());
66                Zip::from(&mut result)
67                    .and(&a_broadcast)
68                    .and(&b_broadcast)
69                    .for_each(|r, a_val, b_val| {
70                        *r = a_val == b_val;
71                    });
72
73                Ok(Tensor::<bool>::from_array(result))
74            }
75            #[cfg(feature = "gpu")]
76            _ => {
77                // Use the high-level comparison function which handles GPU operations
78                let result = crate::ops::comparison::eq(self, other)?;
79                // Convert from u8 to bool tensor
80                match result.storage {
81                    TensorStorage::Cpu(arr) => {
82                        let bool_arr = arr.mapv(|x| x != 0);
83                        Ok(Tensor::<bool>::from_array(bool_arr))
84                    }
85                    #[cfg(feature = "gpu")]
86                    TensorStorage::Gpu(ref gpu_buf) => {
87                        // For GPU, we need to convert u8 to bool
88                        let cpu_result = gpu_buf.to_cpu()?;
89                        let arr = scirs2_core::ndarray::ArrayD::from_shape_vec(
90                            scirs2_core::ndarray::IxDyn(result.shape().dims()),
91                            cpu_result,
92                        )
93                        .map_err(|e| crate::TensorError::invalid_shape_simple(e.to_string()))?;
94                        let bool_arr = arr.mapv(|x| x != 0);
95                        Ok(Tensor::<bool>::from_array(bool_arr))
96                    }
97                }
98            }
99        }
100    }
101
102    /// Element-wise not-equal comparison
103    pub fn ne(&self, other: &Self) -> Result<Tensor<bool>>
104    where
105        T: PartialEq,
106    {
107        let eq_result = self.eq(other)?;
108        match &eq_result.storage {
109            TensorStorage::Cpu(arr) => {
110                let result = arr.mapv(|x| !x);
111                Ok(Tensor::<bool>::from_array(result))
112            }
113            #[cfg(feature = "gpu")]
114            _ => {
115                // Use the high-level comparison function which handles GPU operations
116                let result = crate::ops::comparison::ne(self, other)?;
117                // Convert from u8 to bool tensor
118                match result.storage {
119                    TensorStorage::Cpu(arr) => {
120                        let bool_arr = arr.mapv(|x| x != 0);
121                        Ok(Tensor::<bool>::from_array(bool_arr))
122                    }
123                    #[cfg(feature = "gpu")]
124                    TensorStorage::Gpu(ref gpu_buf) => {
125                        // For GPU, we need to convert u8 to bool
126                        let cpu_result = gpu_buf.to_cpu()?;
127                        let arr = scirs2_core::ndarray::ArrayD::from_shape_vec(
128                            scirs2_core::ndarray::IxDyn(result.shape().dims()),
129                            cpu_result,
130                        )
131                        .map_err(|e| crate::TensorError::invalid_shape_simple(e.to_string()))?;
132                        let bool_arr = arr.mapv(|x| x != 0);
133                        Ok(Tensor::<bool>::from_array(bool_arr))
134                    }
135                }
136            }
137        }
138    }
139
140    /// Element-wise greater-than comparison
141    pub fn gt(&self, other: &Self) -> Result<Tensor<bool>>
142    where
143        T: PartialOrd,
144    {
145        if self.device() != other.device() {
146            return Err(crate::TensorError::device_mismatch(
147                "comparison",
148                &self.device().to_string(),
149                &other.device().to_string(),
150            ));
151        }
152
153        let broadcast_shape = self.shape().broadcast_shape(other.shape()).ok_or_else(|| {
154            crate::TensorError::ShapeMismatch {
155                operation: "broadcast".to_string(),
156                expected: self.shape().to_string(),
157                got: other.shape().to_string(),
158                context: None,
159            }
160        })?;
161
162        match (&self.storage, &other.storage) {
163            (TensorStorage::Cpu(arr_a), TensorStorage::Cpu(arr_b)) => {
164                use scirs2_core::ndarray::{ArrayD, IxDyn, Zip};
165
166                let a_broadcast =
167                    arr_a
168                        .broadcast(IxDyn(broadcast_shape.dims()))
169                        .ok_or_else(|| {
170                            crate::TensorError::invalid_argument(
171                                "Cannot broadcast first tensor".to_string(),
172                            )
173                        })?;
174                let b_broadcast =
175                    arr_b
176                        .broadcast(IxDyn(broadcast_shape.dims()))
177                        .ok_or_else(|| {
178                            crate::TensorError::invalid_argument(
179                                "Cannot broadcast second tensor".to_string(),
180                            )
181                        })?;
182
183                let mut result = ArrayD::default(a_broadcast.raw_dim());
184                Zip::from(&mut result)
185                    .and(&a_broadcast)
186                    .and(&b_broadcast)
187                    .for_each(|r, a_val, b_val| {
188                        *r = a_val > b_val;
189                    });
190
191                Ok(Tensor::<bool>::from_array(result))
192            }
193            #[cfg(feature = "gpu")]
194            _ => {
195                // Use the high-level comparison function which handles GPU operations
196                let result = crate::ops::comparison::gt(self, other)?;
197                // Convert from u8 to bool tensor
198                match result.storage {
199                    TensorStorage::Cpu(arr) => {
200                        let bool_arr = arr.mapv(|x| x != 0);
201                        Ok(Tensor::<bool>::from_array(bool_arr))
202                    }
203                    #[cfg(feature = "gpu")]
204                    TensorStorage::Gpu(ref gpu_buf) => {
205                        // For GPU, we need to convert u8 to bool
206                        let cpu_result = gpu_buf.to_cpu()?;
207                        let arr = scirs2_core::ndarray::ArrayD::from_shape_vec(
208                            scirs2_core::ndarray::IxDyn(result.shape().dims()),
209                            cpu_result,
210                        )
211                        .map_err(|e| crate::TensorError::invalid_shape_simple(e.to_string()))?;
212                        let bool_arr = arr.mapv(|x| x != 0);
213                        Ok(Tensor::<bool>::from_array(bool_arr))
214                    }
215                }
216            }
217        }
218    }
219
220    /// Element-wise greater-than-or-equal comparison
221    pub fn ge(&self, other: &Self) -> Result<Tensor<bool>>
222    where
223        T: PartialOrd,
224    {
225        if self.device() != other.device() {
226            return Err(crate::TensorError::device_mismatch(
227                "comparison",
228                &self.device().to_string(),
229                &other.device().to_string(),
230            ));
231        }
232
233        let broadcast_shape = self.shape().broadcast_shape(other.shape()).ok_or_else(|| {
234            crate::TensorError::ShapeMismatch {
235                operation: "broadcast".to_string(),
236                expected: self.shape().to_string(),
237                got: other.shape().to_string(),
238                context: None,
239            }
240        })?;
241
242        match (&self.storage, &other.storage) {
243            (TensorStorage::Cpu(arr_a), TensorStorage::Cpu(arr_b)) => {
244                use scirs2_core::ndarray::{ArrayD, IxDyn, Zip};
245
246                let a_broadcast =
247                    arr_a
248                        .broadcast(IxDyn(broadcast_shape.dims()))
249                        .ok_or_else(|| {
250                            crate::TensorError::invalid_argument(
251                                "Cannot broadcast first tensor".to_string(),
252                            )
253                        })?;
254                let b_broadcast =
255                    arr_b
256                        .broadcast(IxDyn(broadcast_shape.dims()))
257                        .ok_or_else(|| {
258                            crate::TensorError::invalid_argument(
259                                "Cannot broadcast second tensor".to_string(),
260                            )
261                        })?;
262
263                let mut result = ArrayD::default(a_broadcast.raw_dim());
264                Zip::from(&mut result)
265                    .and(&a_broadcast)
266                    .and(&b_broadcast)
267                    .for_each(|r, a_val, b_val| {
268                        *r = a_val >= b_val;
269                    });
270
271                Ok(Tensor::<bool>::from_array(result))
272            }
273            #[cfg(feature = "gpu")]
274            _ => {
275                // Use the high-level comparison function which handles GPU operations
276                let result = crate::ops::comparison::ge(self, other)?;
277                // Convert from u8 to bool tensor
278                match result.storage {
279                    TensorStorage::Cpu(arr) => {
280                        let bool_arr = arr.mapv(|x| x != 0);
281                        Ok(Tensor::<bool>::from_array(bool_arr))
282                    }
283                    #[cfg(feature = "gpu")]
284                    TensorStorage::Gpu(ref gpu_buf) => {
285                        // For GPU, we need to convert u8 to bool
286                        let cpu_result = gpu_buf.to_cpu()?;
287                        let arr = scirs2_core::ndarray::ArrayD::from_shape_vec(
288                            scirs2_core::ndarray::IxDyn(result.shape().dims()),
289                            cpu_result,
290                        )
291                        .map_err(|e| crate::TensorError::invalid_shape_simple(e.to_string()))?;
292                        let bool_arr = arr.mapv(|x| x != 0);
293                        Ok(Tensor::<bool>::from_array(bool_arr))
294                    }
295                }
296            }
297        }
298    }
299
300    /// Element-wise less-than comparison
301    pub fn lt(&self, other: &Self) -> Result<Tensor<bool>>
302    where
303        T: PartialOrd,
304    {
305        other.gt(self)
306    }
307
308    /// Element-wise less-than-or-equal comparison
309    pub fn le(&self, other: &Self) -> Result<Tensor<bool>>
310    where
311        T: PartialOrd,
312    {
313        other.ge(self)
314    }
315}
316
317// Boolean tensor specific operations
318impl Tensor<bool> {
319    /// Cast boolean tensor to u8 tensor (false -> 0, true -> 1)
320    pub fn cast_to_u8(&self) -> Result<Tensor<u8>> {
321        match &self.storage {
322            TensorStorage::Cpu(arr) => {
323                let u8_arr = arr.mapv(|x| if x { 1u8 } else { 0u8 });
324                Ok(Tensor::<u8>::from_array(u8_arr))
325            }
326            #[cfg(feature = "gpu")]
327            TensorStorage::Gpu(_) => {
328                // For now, GPU bool->u8 casting not implemented
329                Err(crate::TensorError::unsupported_operation_simple(
330                    "GPU bool to u8 casting not yet implemented".to_string(),
331                ))
332            }
333        }
334    }
335    /// Element-wise logical AND operation
336    pub fn logical_and(&self, other: &Self) -> Result<Self> {
337        if self.device() != other.device() {
338            return Err(crate::TensorError::device_mismatch(
339                "comparison",
340                &self.device().to_string(),
341                &other.device().to_string(),
342            ));
343        }
344
345        let broadcast_shape = self.shape().broadcast_shape(other.shape()).ok_or_else(|| {
346            crate::TensorError::ShapeMismatch {
347                operation: "broadcast".to_string(),
348                expected: self.shape().to_string(),
349                got: other.shape().to_string(),
350                context: None,
351            }
352        })?;
353
354        match (&self.storage, &other.storage) {
355            (TensorStorage::Cpu(arr_a), TensorStorage::Cpu(arr_b)) => {
356                use scirs2_core::ndarray::{ArrayD, IxDyn, Zip};
357
358                let a_broadcast =
359                    arr_a
360                        .broadcast(IxDyn(broadcast_shape.dims()))
361                        .ok_or_else(|| {
362                            crate::TensorError::invalid_argument(
363                                "Cannot broadcast first tensor".to_string(),
364                            )
365                        })?;
366                let b_broadcast =
367                    arr_b
368                        .broadcast(IxDyn(broadcast_shape.dims()))
369                        .ok_or_else(|| {
370                            crate::TensorError::invalid_argument(
371                                "Cannot broadcast second tensor".to_string(),
372                            )
373                        })?;
374
375                let mut result = ArrayD::default(a_broadcast.raw_dim());
376                Zip::from(&mut result)
377                    .and(&a_broadcast)
378                    .and(&b_broadcast)
379                    .for_each(|r, a_val, b_val| {
380                        *r = *a_val && *b_val;
381                    });
382
383                Ok(Tensor::<bool>::from_array(result))
384            }
385            #[cfg(feature = "gpu")]
386            _ => {
387                // Convert bool tensors to u8 for GPU logical operations
388                let self_u8 = self.cast_to_u8()?;
389                let other_u8 = other.cast_to_u8()?;
390
391                // Use the high-level logical function which handles GPU operations
392                let result_u8 = crate::ops::logical::logical_and(&self_u8, &other_u8)?;
393
394                // Convert result back to bool tensor
395                result_u8.cast_to_bool()
396            }
397        }
398    }
399
400    /// Element-wise logical OR operation
401    pub fn logical_or(&self, other: &Self) -> Result<Self> {
402        if self.device() != other.device() {
403            return Err(crate::TensorError::device_mismatch(
404                "comparison",
405                &self.device().to_string(),
406                &other.device().to_string(),
407            ));
408        }
409
410        let broadcast_shape = self.shape().broadcast_shape(other.shape()).ok_or_else(|| {
411            crate::TensorError::ShapeMismatch {
412                operation: "broadcast".to_string(),
413                expected: self.shape().to_string(),
414                got: other.shape().to_string(),
415                context: None,
416            }
417        })?;
418
419        match (&self.storage, &other.storage) {
420            (TensorStorage::Cpu(arr_a), TensorStorage::Cpu(arr_b)) => {
421                use scirs2_core::ndarray::{ArrayD, IxDyn, Zip};
422
423                let a_broadcast =
424                    arr_a
425                        .broadcast(IxDyn(broadcast_shape.dims()))
426                        .ok_or_else(|| {
427                            crate::TensorError::invalid_argument(
428                                "Cannot broadcast first tensor".to_string(),
429                            )
430                        })?;
431                let b_broadcast =
432                    arr_b
433                        .broadcast(IxDyn(broadcast_shape.dims()))
434                        .ok_or_else(|| {
435                            crate::TensorError::invalid_argument(
436                                "Cannot broadcast second tensor".to_string(),
437                            )
438                        })?;
439
440                let mut result = ArrayD::default(a_broadcast.raw_dim());
441                Zip::from(&mut result)
442                    .and(&a_broadcast)
443                    .and(&b_broadcast)
444                    .for_each(|r, a_val, b_val| {
445                        *r = *a_val || *b_val;
446                    });
447
448                Ok(Tensor::<bool>::from_array(result))
449            }
450            #[cfg(feature = "gpu")]
451            _ => {
452                // Convert bool tensors to u8 for GPU logical operations
453                let self_u8 = self.cast_to_u8()?;
454                let other_u8 = other.cast_to_u8()?;
455
456                // Use the high-level logical function which handles GPU operations
457                let result_u8 = crate::ops::logical::logical_or(&self_u8, &other_u8)?;
458
459                // Convert result back to bool tensor
460                result_u8.cast_to_bool()
461            }
462        }
463    }
464
465    /// Element-wise logical NOT operation
466    pub fn logical_not(&self) -> Result<Self> {
467        match &self.storage {
468            TensorStorage::Cpu(arr) => {
469                let result = arr.mapv(|x| !x);
470                Ok(Tensor::<bool>::from_array(result))
471            }
472            #[cfg(feature = "gpu")]
473            _ => {
474                // Convert bool tensor to u8 for GPU logical operations
475                let self_u8 = self.cast_to_u8()?;
476
477                // Use the high-level logical function which handles GPU operations
478                let result_u8 = crate::ops::logical::logical_not(&self_u8)?;
479
480                // Convert result back to bool tensor
481                result_u8.cast_to_bool()
482            }
483        }
484    }
485
486    /// Element-wise logical XOR operation
487    pub fn logical_xor(&self, other: &Self) -> Result<Self> {
488        if self.device() != other.device() {
489            return Err(crate::TensorError::device_mismatch(
490                "comparison",
491                &self.device().to_string(),
492                &other.device().to_string(),
493            ));
494        }
495
496        let broadcast_shape = self.shape().broadcast_shape(other.shape()).ok_or_else(|| {
497            crate::TensorError::ShapeMismatch {
498                operation: "broadcast".to_string(),
499                expected: self.shape().to_string(),
500                got: other.shape().to_string(),
501                context: None,
502            }
503        })?;
504
505        match (&self.storage, &other.storage) {
506            (TensorStorage::Cpu(arr_a), TensorStorage::Cpu(arr_b)) => {
507                use scirs2_core::ndarray::{ArrayD, IxDyn, Zip};
508
509                let a_broadcast =
510                    arr_a
511                        .broadcast(IxDyn(broadcast_shape.dims()))
512                        .ok_or_else(|| {
513                            crate::TensorError::invalid_argument(
514                                "Cannot broadcast first tensor".to_string(),
515                            )
516                        })?;
517                let b_broadcast =
518                    arr_b
519                        .broadcast(IxDyn(broadcast_shape.dims()))
520                        .ok_or_else(|| {
521                            crate::TensorError::invalid_argument(
522                                "Cannot broadcast second tensor".to_string(),
523                            )
524                        })?;
525
526                let mut result = ArrayD::default(a_broadcast.raw_dim());
527                Zip::from(&mut result)
528                    .and(&a_broadcast)
529                    .and(&b_broadcast)
530                    .for_each(|r, a_val, b_val| {
531                        *r = *a_val ^ *b_val;
532                    });
533
534                Ok(Tensor::<bool>::from_array(result))
535            }
536            #[cfg(feature = "gpu")]
537            _ => {
538                // Convert bool tensors to u8 for GPU logical operations
539                let self_u8 = self.cast_to_u8()?;
540                let other_u8 = other.cast_to_u8()?;
541
542                // Use the high-level logical function which handles GPU operations
543                let result_u8 = crate::ops::logical::logical_xor(&self_u8, &other_u8)?;
544
545                // Convert result back to bool tensor
546                result_u8.cast_to_bool()
547            }
548        }
549    }
550
551    /// Reduce tensor using logical AND along specified axes
552    pub fn all(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self> {
553        crate::ops::reduction::all(self, axes, keepdims)
554    }
555
556    /// Reduce tensor using logical OR along specified axes
557    pub fn any(&self, axes: Option<&[i32]>, keepdims: bool) -> Result<Self> {
558        crate::ops::reduction::any(self, axes, keepdims)
559    }
560}
561
562// U8 tensor specific operations
563impl Tensor<u8> {
564    /// Cast u8 tensor to boolean tensor (0 -> false, non-zero -> true)
565    pub fn cast_to_bool(&self) -> Result<Tensor<bool>> {
566        match &self.storage {
567            TensorStorage::Cpu(arr) => {
568                let bool_arr = arr.mapv(|x| x != 0);
569                Ok(Tensor::<bool>::from_array(bool_arr))
570            }
571            #[cfg(feature = "gpu")]
572            TensorStorage::Gpu(_) => {
573                // For now, GPU u8->bool casting not implemented
574                Err(crate::TensorError::unsupported_operation_simple(
575                    "GPU u8 to bool casting not yet implemented".to_string(),
576                ))
577            }
578        }
579    }
580}