Skip to main content

svod_tensor/
conditional.rs

1//! Conditional and selection operations for tensors.
2//!
3//! This module provides element-wise conditional operations like where, maximum,
4//! minimum, and clamp that are fundamental for many ML operations.
5
6use bon::bon;
7use snafu::ResultExt;
8use svod_ir::UOp;
9
10use crate::{Result, Tensor, error::UOpSnafu};
11
12#[bon]
13impl Tensor {
14    /// Element-wise conditional selection: `condition ? self : other`.
15    ///
16    /// For each element, returns `self[i]` if `condition[i]` is true, else `other[i]`.
17    ///
18    /// # Arguments
19    /// * `condition` - Boolean tensor (dtype should be Bool or will be treated as boolean)
20    /// * `other` - Alternative value tensor
21    ///
22    /// # Shape Requirements
23    /// All three tensors (self, condition, other) must be broadcastable to the same shape.
24    ///
25    /// # Examples
26    /// ```ignore
27    /// let x = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
28    /// let condition = &x.gt(&Tensor::from_slice(&[2.0f32]))?; // [false, false, true, true]
29    /// let zeros = Tensor::from_slice(&[0.0f32]);
30    ///
31    /// // Replace values > 2.0 with the original value, else 0
32    /// let result = x.where_(condition, &zeros)?;
33    /// // result = [0.0, 0.0, 3.0, 4.0]
34    /// ```
35    pub fn where_(&self, condition: &Tensor, other: &Tensor) -> Result<Self> {
36        use svod_ir::shape::{align_shapes_left, broadcast_shapes};
37
38        let cond_shape = condition.shape()?;
39        let self_shape = self.shape()?;
40        let other_shape = other.shape()?;
41
42        // Broadcast all three to a common shape
43        let aligned = align_shapes_left(&[cond_shape.clone(), self_shape.clone(), other_shape.clone()]);
44        let target = broadcast_shapes(&aligned).context(UOpSnafu)?;
45
46        let cond_bc = condition.broadcast_to(&target)?;
47        let self_bc = self.broadcast_to(&target)?;
48        let other_bc = other.broadcast_to(&target)?;
49
50        let result = UOp::try_where(cond_bc.uop(), self_bc.uop(), other_bc.uop()).context(UOpSnafu)?;
51        Ok(Self::new(result))
52    }
53
54    /// Element-wise maximum: `max(self, other)`.
55    ///
56    /// Returns the element-wise maximum of two tensors.
57    /// This is NOT a reduction - it returns a tensor of the same shape.
58    ///
59    /// # Shape Requirements
60    /// Both tensors must be broadcastable to the same shape.
61    ///
62    /// # Examples
63    /// ```ignore
64    /// let a = Tensor::from_slice(&[1.0f32, 5.0, 3.0]);
65    /// let b = Tensor::from_slice(&[2.0f32, 3.0, 4.0]);
66    /// let result = a.maximum(&b)?;
67    /// // result = [2.0, 5.0, 4.0]
68    /// ```
69    pub fn maximum(&self, other: &Tensor) -> Result<Self> {
70        let (lhs, rhs) = self.broadcast_for_binop(other)?;
71        let result = lhs.uop().try_max(&rhs.uop()).context(UOpSnafu)?;
72        Ok(Self::new(result))
73    }
74
75    /// Element-wise minimum: `min(self, other)`.
76    ///
77    /// Returns the element-wise minimum of two tensors.
78    /// This is NOT a reduction - it returns a tensor of the same shape.
79    ///
80    /// # Shape Requirements
81    /// Both tensors must be broadcastable to the same shape.
82    ///
83    /// # Examples
84    /// ```ignore
85    /// let a = Tensor::from_slice(&[1.0f32, 5.0, 3.0]);
86    /// let b = Tensor::from_slice(&[2.0f32, 3.0, 4.0]);
87    /// let result = a.minimum(&b)?;
88    /// // result = [1.0, 3.0, 3.0]
89    /// ```
90    pub fn minimum(&self, other: &Tensor) -> Result<Self> {
91        // Minimum is not a primitive, we implement it as: -max(-a, -b)
92        // Or equivalently: where(a < b, a, b)
93        let condition = self.try_lt(other)?;
94        self.where_(&condition, other)
95    }
96
97    /// Clamp values to a range: `max(min_val, min(self, max_val))`.
98    ///
99    /// Constrains all elements to be within [min_val, max_val].
100    ///
101    /// # Examples
102    /// ```ignore
103    /// let x = Tensor::from_slice(&[-1.0f32, 0.0, 1.0, 2.0, 3.0]);
104    /// let min = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0]);
105    /// let max = Tensor::from_slice(&[2.0f32, 2.0, 2.0, 2.0, 2.0]);
106    ///
107    /// // Clamp to [0, 2]
108    /// let result = x.clamp().min(&min).max(&max).call()?;
109    /// // result = [0.0, 0.0, 1.0, 2.0, 2.0]
110    ///
111    /// // Clamp only lower bound
112    /// let result = x.clamp().min(&min).call()?;
113    /// // result = [0.0, 0.0, 1.0, 2.0, 3.0]
114    ///
115    /// // Clamp only upper bound
116    /// let result = x.clamp().max(&max).call()?;
117    /// // result = [-1.0, 0.0, 1.0, 2.0, 2.0]
118    /// ```
119    #[builder]
120    pub fn clamp(&self, min: Option<&Tensor>, max: Option<&Tensor>) -> Result<Self> {
121        let mut result = self.clone();
122
123        if let Some(min_val) = min {
124            result = result.maximum(min_val)?;
125        }
126
127        if let Some(max_val) = max {
128            result = result.minimum(max_val)?;
129        }
130
131        Ok(result)
132    }
133
134    /// Alias for `clamp` (matches NumPy/PyTorch naming).
135    ///
136    /// # Examples
137    /// ```ignore
138    /// let x = Tensor::from_slice(&[-1.0f32, 0.0, 1.0, 2.0, 3.0]);
139    /// let min = Tensor::from_slice(&[0.0f32, 0.0, 0.0, 0.0, 0.0]);
140    /// let max = Tensor::from_slice(&[2.0f32, 2.0, 2.0, 2.0, 2.0]);
141    ///
142    /// // Clip to [0, 2]
143    /// let result = x.clip().min(&min).max(&max).call()?;
144    /// ```
145    #[builder]
146    pub fn clip(&self, min: Option<&Tensor>, max: Option<&Tensor>) -> Result<Self> {
147        self.clamp().maybe_min(min).maybe_max(max).call()
148    }
149}